235 files changed, 8691 insertions, 3109 deletions
@@ -4,10 +4,10 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # Load go bazel rules and gazelle. http_archive( name = "io_bazel_rules_go", - sha256 = "94f90feaa65c9cdc840cd21f67d967870b5943d684966a47569da8073e42063d", + sha256 = "db2b2d35293f405430f553bc7a865a8749a8ef60c30287e90d2b278c32771afe", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz", ], ) @@ -25,7 +25,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe go_rules_dependencies() go_register_toolchains( - go_version = "1.14", + go_version = "1.14.2", nogo = "@//:nogo", ) @@ -99,11 +99,11 @@ pip_install() # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( name = "bazel_toolchains", - sha256 = "b5a8039df7119d618402472f3adff8a1bd0ae9d5e253f53fcc4c47122e91a3d2", - strip_prefix = "bazel-toolchains-2.1.1", + sha256 = "239a1a673861eabf988e9804f45da3b94da28d1aff05c373b013193c315d9d9e", + strip_prefix = "bazel-toolchains-3.0.1", urls = [ - "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.1.1/bazel-toolchains-2.1.1.tar.gz", - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.1.1.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz", ], ) @@ -400,6 +400,20 @@ go_repository( version = "v0.20.0", ) +go_repository( + name = "org_uber_go_atomic", + importpath = "go.uber.org/atomic", + version = "v1.6.0", + sum = "h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=", +) + +go_repository( + name = "org_uber_go_multierr", + importpath = "go.uber.org/multierr", + version = "v1.5.0", + sum = "h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=", +) + # BigQuery Dependencies for Benchmarks go_repository( name = "com_google_cloud_go", diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py index ca785a148..fc59cf505 100644 --- a/benchmarks/runner/__init__.py +++ b/benchmarks/runner/__init__.py @@ -19,6 +19,7 @@ import logging import pkgutil import pydoc import re +import subprocess import sys import types from typing import List @@ -125,9 +126,8 @@ def run_gcp(ctx, image_file: str, zone_file: str, internal: bool, """Runs all benchmarks on GCP instances.""" # Resolve all files. 
- image = open(image_file).read().rstrip() - zone = open(zone_file).read().rstrip() - + image = subprocess.check_output([image_file]).rstrip() + zone = subprocess.check_output([zone_file]).rstrip() key_file = harness.make_key() producer = gcloud_producer.GCloudProducer( diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py index 194804527..e8289f6c5 100644 --- a/benchmarks/runner/commands.py +++ b/benchmarks/runner/commands.py @@ -101,15 +101,15 @@ class GCPCommand(RunCommand): image_file = click.core.Option( ("--image_file",), - help="The file containing the image for VMs.", + help="The binary that emits the GCP image.", default=os.path.join( - os.path.dirname(__file__), "../../tools/images/ubuntu1604.txt"), + os.path.dirname(__file__), "../../tools/images/ubuntu1604"), ) zone_file = click.core.Option( ("--zone_file",), - help="The file containing the GCP zone.", + help="The binary that emits the GCP zone.", default=os.path.join( - os.path.dirname(__file__), "../../tools/images/zone.txt"), + os.path.dirname(__file__), "../../tools/images/zone"), ) internal = click.core.Option( ("--internal/--no-internal",), diff --git a/pkg/context/context.go b/pkg/context/context.go index 23e009ef3..5319b6d8d 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -127,10 +127,6 @@ func (logContext) Value(key interface{}) interface{} { var bgContext = &logContext{Logger: log.Log()} // Background returns an empty context using the default logger. -// -// Users should be wary of using a Background context. Please tag any use with -// FIXME(b/38173783) and a note to remove this use. -// // Generally, one should use the Task as their context when available, or avoid // having to use a context in places where a Task is unavailable. // diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 7f41b4a27..43750360b 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -78,7 +78,7 @@ func TestMultiEmitter(t *testing.T) { for _, name := range names { m := testMessage{name: name} if _, err := me.Emit(m); err != nil { - t.Fatal("me.Emit(%v) failed: %v", m, err) + t.Fatalf("me.Emit(%v) failed: %v", m, err) } } @@ -96,7 +96,7 @@ func TestMultiEmitter(t *testing.T) { // Close multiEmitter. if err := me.Close(); err != nil { - t.Fatal("me.Close() failed: %v", err) + t.Fatalf("me.Close() failed: %v", err) } // All testEmitters should be closed. diff --git a/pkg/log/glog.go b/pkg/log/glog.go index b4f7bb5a4..f57c4427b 100644 --- a/pkg/log/glog.go +++ b/pkg/log/glog.go @@ -25,7 +25,7 @@ import ( // GoogleEmitter is a wrapper that emits logs in a format compatible with // package github.com/golang/glog. type GoogleEmitter struct { - Writer + *Writer } // pid is used for the threadid component of the header. @@ -46,7 +46,7 @@ var pid = os.Getpid() // line The line number // msg The user-supplied message // -func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) { +func (g GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) { // Log level. prefix := byte('?') switch level { @@ -81,5 +81,5 @@ func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format message := fmt.Sprintf(format, args...) // Emit the formatted result. 
- fmt.Fprintf(&g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message) + fmt.Fprintf(g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message) } diff --git a/pkg/log/json.go b/pkg/log/json.go index 0943db1cc..bdf9d691e 100644 --- a/pkg/log/json.go +++ b/pkg/log/json.go @@ -58,7 +58,7 @@ func (lv *Level) UnmarshalJSON(b []byte) error { // JSONEmitter logs messages in json format. type JSONEmitter struct { - Writer + *Writer } // Emit implements Emitter.Emit. diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go index 6c6fc8b6f..5883e95e1 100644 --- a/pkg/log/json_k8s.go +++ b/pkg/log/json_k8s.go @@ -29,11 +29,11 @@ type k8sJSONLog struct { // K8sJSONEmitter logs messages in json format that is compatible with // Kubernetes fluent configuration. type K8sJSONEmitter struct { - Writer + *Writer } // Emit implements Emitter.Emit. -func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { +func (e K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { j := k8sJSONLog{ Log: fmt.Sprintf(format, v...), Level: level, diff --git a/pkg/log/log.go b/pkg/log/log.go index a794da1aa..37e0605ad 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -374,5 +374,5 @@ func CopyStandardLogTo(l Level) error { func init() { // Store the initial value for the log. - log.Store(&BasicLogger{Level: Info, Emitter: &GoogleEmitter{Writer{Next: os.Stderr}}}) + log.Store(&BasicLogger{Level: Info, Emitter: GoogleEmitter{&Writer{Next: os.Stderr}}}) } diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go index 402cc29ae..9ff18559b 100644 --- a/pkg/log/log_test.go +++ b/pkg/log/log_test.go @@ -52,7 +52,7 @@ func TestDropMessages(t *testing.T) { t.Fatalf("Write should have failed") } - fmt.Printf("writer: %+v\n", w) + fmt.Printf("writer: %#v\n", &w) tw.fail = false if _, err := w.Write([]byte("line 2\n")); err != nil { @@ -76,7 +76,7 @@ func TestDropMessages(t *testing.T) { func TestCaller(t *testing.T) { tw := &testWriter{} - e := &GoogleEmitter{Writer: Writer{Next: tw}} + e := GoogleEmitter{Writer: &Writer{Next: tw}} bl := &BasicLogger{ Emitter: e, Level: Debug, @@ -94,7 +94,7 @@ func BenchmarkGoogleLogging(b *testing.B) { tw := &testWriter{ limit: 1, // Only record one message. } - e := &GoogleEmitter{Writer: Writer{Next: tw}} + e := GoogleEmitter{Writer: &Writer{Next: tw}} bl := &BasicLogger{ Emitter: e, Level: Debug, diff --git a/pkg/p9/client.go b/pkg/p9/client.go index a6f493b82..71e944c30 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -174,7 +174,7 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client // our sendRecv function to use that functionality. Otherwise, // we stick to sendRecvLegacy. rversion := Rversion{} - err := c.sendRecvLegacy(&Tversion{ + _, err := c.sendRecvLegacy(&Tversion{ Version: versionString(requested), MSize: messageSize, }, &rversion) @@ -219,11 +219,11 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.sendRecv = c.sendRecvChannel } else { // Channel setup failed; fallback. - c.sendRecv = c.sendRecvLegacy + c.sendRecv = c.sendRecvLegacySyscallErr } } else { // No channels available: use the legacy mechanism. 
- c.sendRecv = c.sendRecvLegacy + c.sendRecv = c.sendRecvLegacySyscallErr } // Ensure that the socket and channels are closed when the socket is shut @@ -305,7 +305,7 @@ func (c *Client) openChannel(id int) error { ) // Open the data channel. - if err := c.sendRecvLegacy(&Tchannel{ + if _, err := c.sendRecvLegacy(&Tchannel{ ID: uint32(id), Control: 0, }, &rchannel0); err != nil { @@ -319,7 +319,7 @@ func (c *Client) openChannel(id int) error { defer rchannel0.FilePayload().Close() // Open the channel for file descriptors. - if err := c.sendRecvLegacy(&Tchannel{ + if _, err := c.sendRecvLegacy(&Tchannel{ ID: uint32(id), Control: 1, }, &rchannel1); err != nil { @@ -431,13 +431,28 @@ func (c *Client) waitAndRecv(done chan error) error { } } +// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all +// non-syscall errors to EIO. +func (c *Client) sendRecvLegacySyscallErr(t message, r message) error { + received, err := c.sendRecvLegacy(t, r) + if !received { + log.Warningf("p9.Client.sendRecvChannel: %v", err) + return syscall.EIO + } + return err +} + // sendRecvLegacy performs a roundtrip message exchange. // +// sendRecvLegacy returns true if a message was received. This allows us to +// differentiate between failed receives and successful receives where the +// response was an error message. +// // This is called by internal functions. -func (c *Client) sendRecvLegacy(t message, r message) error { +func (c *Client) sendRecvLegacy(t message, r message) (bool, error) { tag, ok := c.tagPool.Get() if !ok { - return ErrOutOfTags + return false, ErrOutOfTags } defer c.tagPool.Put(tag) @@ -457,12 +472,12 @@ func (c *Client) sendRecvLegacy(t message, r message) error { err := send(c.socket, Tag(tag), t) c.sendMu.Unlock() if err != nil { - return err + return false, err } // Co-ordinate with other receivers. if err := c.waitAndRecv(resp.done); err != nil { - return err + return false, err } // Is it an error message? @@ -470,14 +485,14 @@ func (c *Client) sendRecvLegacy(t message, r message) error { // For convenience, we transform these directly // into errors. Handlers need not handle this case. if rlerr, ok := resp.r.(*Rlerror); ok { - return syscall.Errno(rlerr.Error) + return true, syscall.Errno(rlerr.Error) } // At this point, we know it matches. // // Per recv call above, we will only allow a type // match (and give our r) or an instance of Rlerror. - return nil + return true, nil } // sendRecvChannel uses channels to send a message. @@ -486,7 +501,7 @@ func (c *Client) sendRecvChannel(t message, r message) error { c.channelsMu.Lock() if len(c.availableChannels) == 0 { c.channelsMu.Unlock() - return c.sendRecvLegacy(t, r) + return c.sendRecvLegacySyscallErr(t, r) } idx := len(c.availableChannels) - 1 ch := c.availableChannels[idx] @@ -526,7 +541,11 @@ func (c *Client) sendRecvChannel(t message, r message) error { } // Parse the server's response. - _, retErr := ch.recv(r, rsz) + resp, retErr := ch.recv(r, rsz) + if resp == nil { + log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr) + retErr = syscall.EIO + } // Release the channel. 
c.channelsMu.Lock() diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index 29a0afadf..c757583e0 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -96,7 +96,12 @@ func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) e } func BenchmarkSendRecvLegacy(b *testing.B) { - benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy }) + benchmarkSendRecv(b, func(c *Client) func(message, message) error { + return func(t message, r message) error { + _, err := c.sendRecvLegacy(t, r) + return err + } + }) } func BenchmarkSendRecvChannel(b *testing.B) { diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index a8b714cf5..1db5797dd 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -48,6 +48,8 @@ func ExtractErrno(err error) syscall.Errno { return ExtractErrno(e.Err) case *os.SyscallError: return ExtractErrno(e.Err) + case *os.LinkError: + return ExtractErrno(e.Err) } // Default case. diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index a0d274f3b..38038abdf 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -236,7 +236,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) { // Convert errors appropriately; see above. if rlerr, ok := r.(*Rlerror); ok { - return nil, syscall.Errno(rlerr.Error) + return r, syscall.Errno(rlerr.Error) } return r, nil diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s index 129691d68..00b46c18f 100644 --- a/pkg/safecopy/memcpy_amd64.s +++ b/pkg/safecopy/memcpy_amd64.s @@ -55,15 +55,9 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36 MOVQ from+8(FP), SI MOVQ n+16(FP), BX - // REP instructions have a high startup cost, so we handle small sizes - // with some straightline code. The REP MOVSQ instruction is really fast - // for large sizes. The cutover is approximately 2K. tail: - // move_129through256 or smaller work whether or not the source and the - // destination memory regions overlap because they load all data into - // registers before writing it back. move_256through2048 on the other - // hand can be used only when the memory regions don't overlap or the copy - // direction is forward. + // BSR+branch table make almost all memmove/memclr benchmarks worse. Not + // worth doing. 
TESTQ BX, BX JEQ move_0 CMPQ BX, $2 @@ -83,31 +77,45 @@ tail: JBE move_65through128 CMPQ BX, $256 JBE move_129through256 - // TODO: use branch table and BSR to make this just a single dispatch -/* - * forward copy loop - */ - CMPQ BX, $2048 - JLS move_256through2048 - - // Check alignment - MOVL SI, AX - ORL DI, AX - TESTL $7, AX - JEQ fwdBy8 - - // Do 1 byte at a time - MOVQ BX, CX - REP; MOVSB - RET - -fwdBy8: - // Do 8 bytes at a time - MOVQ BX, CX - SHRQ $3, CX - ANDQ $7, BX - REP; MOVSQ +move_257plus: + SUBQ $256, BX + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU 128(SI), X8 + MOVOU X8, 128(DI) + MOVOU 144(SI), X9 + MOVOU X9, 144(DI) + MOVOU 160(SI), X10 + MOVOU X10, 160(DI) + MOVOU 176(SI), X11 + MOVOU X11, 176(DI) + MOVOU 192(SI), X12 + MOVOU X12, 192(DI) + MOVOU 208(SI), X13 + MOVOU X13, 208(DI) + MOVOU 224(SI), X14 + MOVOU X14, 224(DI) + MOVOU 240(SI), X15 + MOVOU X15, 240(DI) + CMPQ BX, $256 + LEAQ 256(SI), SI + LEAQ 256(DI), DI + JGE move_257plus JMP tail move_1or2: @@ -209,42 +217,3 @@ move_129through256: MOVOU -16(SI)(BX*1), X15 MOVOU X15, -16(DI)(BX*1) RET -move_256through2048: - SUBQ $256, BX - MOVOU (SI), X0 - MOVOU X0, (DI) - MOVOU 16(SI), X1 - MOVOU X1, 16(DI) - MOVOU 32(SI), X2 - MOVOU X2, 32(DI) - MOVOU 48(SI), X3 - MOVOU X3, 48(DI) - MOVOU 64(SI), X4 - MOVOU X4, 64(DI) - MOVOU 80(SI), X5 - MOVOU X5, 80(DI) - MOVOU 96(SI), X6 - MOVOU X6, 96(DI) - MOVOU 112(SI), X7 - MOVOU X7, 112(DI) - MOVOU 128(SI), X8 - MOVOU X8, 128(DI) - MOVOU 144(SI), X9 - MOVOU X9, 144(DI) - MOVOU 160(SI), X10 - MOVOU X10, 160(DI) - MOVOU 176(SI), X11 - MOVOU X11, 176(DI) - MOVOU 192(SI), X12 - MOVOU X12, 192(DI) - MOVOU 208(SI), X13 - MOVOU X13, 208(DI) - MOVOU 224(SI), X14 - MOVOU X14, 224(DI) - MOVOU 240(SI), X15 - MOVOU X15, 240(DI) - CMPQ BX, $256 - LEAQ 256(SI), SI - LEAQ 256(DI), DI - JGE move_256through2048 - JMP tail diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go index f19a005f3..97b16c158 100644 --- a/pkg/segment/test/segment_test.go +++ b/pkg/segment/test/segment_test.go @@ -63,7 +63,7 @@ func checkSet(s *Set, expectedSegments int) error { return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) } if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want) + return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want) } prev = next havePrev = true diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 1d11cc472..a903d031c 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -88,6 +88,9 @@ type Context interface { // SyscallNo returns the syscall number. SyscallNo() uintptr + // SyscallSaveOrig save orignal register value. + SyscallSaveOrig() + // SyscallArgs returns the syscall arguments in an array. 
SyscallArgs() SyscallArguments diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 09bceabc9..1108fa0bd 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -97,7 +97,6 @@ func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) { if c < 0 { return 0, fmt.Errorf("bad binary.Size for %T", v) } - // TODO(b/38173783): Use a real context.Context. n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{}) if err != nil || c != n { return 0, err @@ -121,11 +120,9 @@ func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) { var err error if isVaddr { value := s.Arch.Native(uintptr(0)) - // TODO(b/38173783): Use a real context.Context. n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{}) *vaddr = usermem.Addr(s.Arch.Value(value)) } else { - // TODO(b/38173783): Use a real context.Context. n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{}) } if err != nil { diff --git a/pkg/sentry/arch/syscalls_amd64.go b/pkg/sentry/arch/syscalls_amd64.go index 8b4f23007..3859f41ee 100644 --- a/pkg/sentry/arch/syscalls_amd64.go +++ b/pkg/sentry/arch/syscalls_amd64.go @@ -18,6 +18,13 @@ package arch const restartSyscallNr = uintptr(219) +// SyscallSaveOrig save the value of the register which is clobbered in +// syscall handler(doSyscall()). +// +// Noop on x86. +func (c *context64) SyscallSaveOrig() { +} + // SyscallNo returns the syscall number according to the 64-bit convention. func (c *context64) SyscallNo() uintptr { return uintptr(c.Regs.Orig_rax) diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go index dc13b6124..92d062513 100644 --- a/pkg/sentry/arch/syscalls_arm64.go +++ b/pkg/sentry/arch/syscalls_arm64.go @@ -18,6 +18,17 @@ package arch const restartSyscallNr = uintptr(128) +// SyscallSaveOrig save the value of the register R0 which is clobbered in +// syscall handler(doSyscall()). +// +// In linux, at the entry of the syscall handler(el0_svc_common()), value of R0 +// is saved to the pt_regs.orig_x0 in kernel code. But currently, the orig_x0 +// was not accessible to the user space application, so we have to do the same +// operation in the sentry code to save the R0 value into the App context. +func (c *context64) SyscallSaveOrig() { + c.OrigR0 = c.Regs.Regs[0] +} + // SyscallNo returns the syscall number according to the 64-bit convention. func (c *context64) SyscallNo() uintptr { return uintptr(c.Regs.Regs[8]) @@ -40,7 +51,7 @@ func (c *context64) SyscallNo() uintptr { // R30: the link register. func (c *context64) SyscallArgs() SyscallArguments { return SyscallArguments{ - SyscallArgument{Value: uintptr(c.Regs.Regs[0])}, + SyscallArgument{Value: uintptr(c.OrigR0)}, SyscallArgument{Value: uintptr(c.Regs.Regs[1])}, SyscallArgument{Value: uintptr(c.Regs.Regs[2])}, SyscallArgument{Value: uintptr(c.Regs.Regs[3])}, diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go index 031fc64ec..8e5658c7a 100644 --- a/pkg/sentry/contexttest/contexttest.go +++ b/pkg/sentry/contexttest/contexttest.go @@ -97,7 +97,7 @@ type hostClock struct { } // Now implements ktime.Clock.Now. 
-func (hostClock) Now() ktime.Time { +func (*hostClock) Now() ktime.Time { return ktime.FromNanoseconds(time.Now().UnixNano()) } @@ -127,7 +127,7 @@ func (t *TestContext) Value(key interface{}) interface{} { case uniqueid.CtxInotifyCookie: return atomic.AddUint32(&lastInotifyCookie, 1) case ktime.CtxRealtimeClock: - return hostClock{} + return &hostClock{} default: if val, ok := t.otherValues[key]; ok { return val diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 0266a5287..65be12175 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -312,9 +312,9 @@ func (d *Dirent) SyncAll(ctx context.Context) { // There is nothing to sync for a read-only filesystem. if !d.Inode.MountSource.Flags.ReadOnly { - // FIXME(b/34856369): This should be a mount traversal, not a - // Dirent traversal, because some Inodes that need to be synced - // may no longer be reachable by name (after sys_unlink). + // NOTE(b/34856369): This should be a mount traversal, not a Dirent + // traversal, because some Inodes that need to be synced may no longer + // be reachable by name (after sys_unlink). // // Write out metadata, dirty page cached pages, and sync disk/remote // caches. diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index 5aff0cc95..a0082ecca 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -119,7 +119,7 @@ func TestNewPipe(t *testing.T) { continue } if flags := p.flags; test.flags != flags { - t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags) + t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags) continue } if len(test.readAheadBuffer) != len(p.readAheadBuffer) { @@ -136,7 +136,7 @@ func TestNewPipe(t *testing.T) { continue } if !fdnotifier.HasFD(int32(f.FD())) { - t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD) + t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD()) } } } diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go index ff96b28ba..edd6576aa 100644 --- a/pkg/sentry/fs/gofer/file_state.go +++ b/pkg/sentry/fs/gofer/file_state.go @@ -34,7 +34,6 @@ func (f *fileOperations) afterLoad() { flags := f.flags flags.Truncate = false - // TODO(b/38173783): Context is not plumbed to save/restore. f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps) if err != nil { return fmt.Errorf("failed to re-open handle: %v", err) diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index 9f7c3e89f..fc14249be 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -57,7 +57,6 @@ func (h *handles) DecRef() { } } } - // FIXME(b/38173783): Context is not plumbed here. if err := h.File.close(context.Background()); err != nil { log.Warningf("error closing p9 file: %v", err) } diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 1c934981b..a016c896e 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -273,7 +273,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle // operations on the old will see the new data. Then, make the new handle take // ownereship of the old FD and mark the old readHandle to not close the FD // when done. 
- if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil { + if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), syscall.O_CLOEXEC); err != nil { return err } @@ -710,13 +710,10 @@ func init() { } // AddLink implements InodeOperations.AddLink, but is currently a noop. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (*inodeOperations) AddLink() {} // DropLink implements InodeOperations.DropLink, but is currently a noop. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (*inodeOperations) DropLink() {} // NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {} diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go index 238f7804c..a3402e343 100644 --- a/pkg/sentry/fs/gofer/inode_state.go +++ b/pkg/sentry/fs/gofer/inode_state.go @@ -123,7 +123,6 @@ func (i *inodeFileState) afterLoad() { // beforeSave. return fmt.Errorf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings)) } - // TODO(b/38173783): Context is not plumbed to save/restore. ctx := &dummyClockContext{context.Background()} _, i.file, err = i.s.attach.walk(ctx, splitAbsolutePath(name)) diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go index 111da59f9..2d398b753 100644 --- a/pkg/sentry/fs/gofer/session_state.go +++ b/pkg/sentry/fs/gofer/session_state.go @@ -104,7 +104,6 @@ func (s *session) afterLoad() { // If private unix sockets are enabled, create and fill the session's endpoint // maps. if opts.privateunixsocket { - // TODO(b/38173783): Context is not plumbed to save/restore. ctx := &dummyClockContext{context.Background()} if err = s.restoreEndpointMaps(ctx); err != nil { diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go index 2d8d3a2ea..47a6c69bf 100644 --- a/pkg/sentry/fs/gofer/util.go +++ b/pkg/sentry/fs/gofer/util.go @@ -20,17 +20,29 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fs" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) func utimes(ctx context.Context, file contextFile, ts fs.TimeSpec) error { if ts.ATimeOmit && ts.MTimeOmit { return nil } + + // Replace requests to use the "system time" with the current time to + // ensure that timestamps remain consistent with the remote + // filesystem. + now := ktime.NowFromContext(ctx) + if ts.ATimeSetSystemTime { + ts.ATime = now + } + if ts.MTimeSetSystemTime { + ts.MTime = now + } mask := p9.SetAttrMask{ ATime: !ts.ATimeOmit, - ATimeNotSystemTime: !ts.ATimeSetSystemTime, + ATimeNotSystemTime: true, MTime: !ts.MTimeOmit, - MTimeNotSystemTime: !ts.MTimeSetSystemTime, + MTimeNotSystemTime: true, } as, ans := ts.ATime.Unix() ms, mns := ts.MTime.Unix() diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 1da3c0a17..62f1246aa 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -397,15 +397,12 @@ func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) { } // AddLink implements fs.InodeOperations.AddLink. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (i *inodeOperations) AddLink() {} // DropLink implements fs.InodeOperations.DropLink. -// FIXME(b/63117438): Remove this from InodeOperations altogether. 
func (i *inodeOperations) DropLink() {} // NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {} // readdirAll returns all of the directory entries in i. diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index eb4afe520..affdbcacb 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -199,14 +199,14 @@ func TestListen(t *testing.T) { } func TestPasscred(t *testing.T) { - e := ConnectedEndpoint{} + e := &ConnectedEndpoint{} if got, want := e.Passcred(), false; got != want { t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) } } func TestGetLocalAddress(t *testing.T) { - e := ConnectedEndpoint{path: "foo"} + e := &ConnectedEndpoint{path: "foo"} want := tcpip.FullAddress{Addr: tcpip.Address("foo")} if got, err := e.GetLocalAddress(); err != nil || got != want { t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil) @@ -214,7 +214,7 @@ func TestGetLocalAddress(t *testing.T) { } func TestQueuedSize(t *testing.T) { - e := ConnectedEndpoint{} + e := &ConnectedEndpoint{} tests := []struct { name string f func() int64 diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 55fb71c16..a34fbc946 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -102,7 +102,6 @@ func (i *Inode) DecRef() { // destroy releases the Inode and releases the msrc reference taken. func (i *Inode) destroy() { - // FIXME(b/38173783): Context is not plumbed here. ctx := context.Background() if err := i.WriteOut(ctx); err != nil { // FIXME(b/65209558): Mark as warning again once noatime is @@ -397,8 +396,6 @@ func (i *Inode) Getlink(ctx context.Context) (*Dirent, error) { // AddLink calls i.InodeOperations.AddLink. func (i *Inode) AddLink() { if i.overlay != nil { - // FIXME(b/63117438): Remove this from InodeOperations altogether. - // // This interface is only used by ramfs to update metadata of // children. These filesystems should _never_ have overlay // Inodes cached as children. So explicitly disallow this diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index d4c4b533d..702fdd392 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -80,7 +80,7 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir } // Truncate implements fs.InodeOperations.Truncate. -func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { +func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { return nil } @@ -196,7 +196,7 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f } // Truncate implements fs.InodeOperations.Truncate. -func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error { +func (*tcpSack) Truncate(context.Context, *fs.Inode, int64) error { return nil } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index d6c5dd2c1..4d42eac83 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -57,6 +57,16 @@ func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) { return m, nil } +func checkTaskState(t *kernel.Task) error { + switch t.ExitState() { + case kernel.TaskExitZombie: + return syserror.EACCES + case kernel.TaskExitDead: + return syserror.ESRCH + } + return nil +} + // taskDir represents a task-level directory. 
// // +stateify savable @@ -254,11 +264,12 @@ func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { } func (e *exe) executable() (file fsbridge.File, err error) { + if err := checkTaskState(e.t); err != nil { + return nil, err + } e.t.WithMuLocked(func(t *kernel.Task) { mm := t.MemoryManager() if mm == nil { - // TODO(b/34851096): Check shouldn't allow Readlink once the - // Task is zombied. err = syserror.EACCES return } @@ -268,7 +279,7 @@ func (e *exe) executable() (file fsbridge.File, err error) { // (with locks held). file = mm.Executable() if file == nil { - err = syserror.ENOENT + err = syserror.ESRCH } }) return @@ -313,11 +324,22 @@ func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs. return newProcInode(t, n, msrc, fs.Symlink, t) } +// Readlink reads the symlink value. +func (n *namespaceSymlink) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { + if err := checkTaskState(n.t); err != nil { + return "", err + } + return n.Symlink.Readlink(ctx, inode) +} + // Getlink implements fs.InodeOperations.Getlink. func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Dirent, error) { if !kernel.ContextCanTrace(ctx, n.t, false) { return nil, syserror.EACCES } + if err := checkTaskState(n.t); err != nil { + return nil, err + } // Create a new regular file to fake the namespace file. iops := fsutil.NewNoReadWriteFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0777), linux.PROC_SUPER_MAGIC) diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go index d5be56c3f..bc117ca6a 100644 --- a/pkg/sentry/fs/tmpfs/fs.go +++ b/pkg/sentry/fs/tmpfs/fs.go @@ -44,9 +44,6 @@ const ( // lookup. cacheRevalidate = "revalidate" - // TODO(edahlgren/mpratt): support a tmpfs size limit. - // size = "size" - // Permissions that exceed modeMask will be rejected. modeMask = 01777 diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 48eaccdbc..afea58f65 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -476,7 +476,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { _, _, err := fs.walk(rp, false) if err != nil { return nil, err @@ -485,7 +485,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([ } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { _, _, err := fs.walk(rp, false) if err != nil { return "", err diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 137260898..cd744bf5e 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -1080,7 +1080,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. 
-func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(&ds) @@ -1088,11 +1088,11 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([ if err != nil { return nil, err } - return d.listxattr(ctx) + return d.listxattr(ctx, rp.Credentials(), size) } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(&ds) @@ -1100,7 +1100,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, nam if err != nil { return "", err } - return d.getxattr(ctx, name) + return d.getxattr(ctx, rp.Credentials(), &opts) } // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. @@ -1112,7 +1112,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt if err != nil { return err } - return d.setxattr(ctx, &opts) + return d.setxattr(ctx, rp.Credentials(), &opts) } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. @@ -1124,7 +1124,7 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, if err != nil { return err } - return d.removexattr(ctx, name) + return d.removexattr(ctx, rp.Credentials(), name) } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 20edaf643..2485cdb53 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -34,6 +34,7 @@ package gofer import ( "fmt" "strconv" + "strings" "sync" "sync/atomic" "syscall" @@ -1024,21 +1025,50 @@ func (d *dentry) setDeleted() { atomic.StoreUint32(&d.deleted, 1) } -func (d *dentry) listxattr(ctx context.Context) ([]string, error) { - return nil, syserror.ENOTSUP +// We only support xattrs prefixed with "user." (see b/148380782). Currently, +// there is no need to expose any other xattrs through a gofer. 
+func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { + xattrMap, err := d.file.listXattr(ctx, size) + if err != nil { + return nil, err + } + xattrs := make([]string, 0, len(xattrMap)) + for x := range xattrMap { + if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) { + xattrs = append(xattrs, x) + } + } + return xattrs, nil } -func (d *dentry) getxattr(ctx context.Context, name string) (string, error) { - // TODO(jamieliu): add vfs.GetxattrOptions.Size - return d.file.getXattr(ctx, name, linux.XATTR_SIZE_MAX) +func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { + if err := d.checkPermissions(creds, vfs.MayRead); err != nil { + return "", err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return "", syserror.EOPNOTSUPP + } + return d.file.getXattr(ctx, opts.Name, opts.Size) } -func (d *dentry) setxattr(ctx context.Context, opts *vfs.SetxattrOptions) error { +func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error { + if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } -func (d *dentry) removexattr(ctx context.Context, name string) error { - return syserror.ENOTSUP +func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error { + if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + return d.file.removeXattr(ctx, name) } // Preconditions: d.isRegularFile() || d.isDirectory(). @@ -1089,7 +1119,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // description, but this doesn't matter since they refer to the // same file (unless d.fs.opts.overlayfsStaleRead is true, // which we handle separately). - if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil { + if err := syscall.Dup3(int(h.fd), int(d.handle.fd), syscall.O_CLOEXEC); err != nil { d.handleMu.Unlock() ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err) h.close(ctx) @@ -1189,21 +1219,21 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) } // Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context) ([]string, error) { - return fd.dentry().listxattr(ctx) +func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { + return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size) } // Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, name string) (string, error) { - return fd.dentry().getxattr(ctx, name) +func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { + return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts) } // Setxattr implements vfs.FileDescriptionImpl.Setxattr. func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { - return fd.dentry().setxattr(ctx, &opts) + return fd.dentry().setxattr(ctx, auth.CredentialsFromContext(ctx), &opts) } // Removexattr implements vfs.FileDescriptionImpl.Removexattr. 
func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { - return fd.dentry().removexattr(ctx, name) + return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name) } diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 755ac2985..87f0b877f 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -85,6 +85,13 @@ func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAt return err } +func (f p9file) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) { + ctx.UninterruptibleSleepStart(false) + xattrs, err := f.file.ListXattr(size) + ctx.UninterruptibleSleepFinish(false) + return xattrs, err +} + func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) { ctx.UninterruptibleSleepStart(false) val, err := f.file.GetXattr(name, size) @@ -99,6 +106,13 @@ func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32) return err } +func (f p9file) removeXattr(ctx context.Context, name string) error { + ctx.UninterruptibleSleepStart(false) + err := f.file.RemoveXattr(name) + ctx.UninterruptibleSleepFinish(false) + return err +} + func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error { ctx.UninterruptibleSleepStart(false) err := f.file.Allocate(mode, offset, length) diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 16a3c18ae..baf81b4db 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -682,7 +682,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu if err != nil { return linux.Statfs{}, err } - // TODO: actually implement statfs + // TODO(gvisor.dev/issue/1193): actually implement statfs. return linux.Statfs{}, syserror.ENOSYS } @@ -763,7 +763,7 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() @@ -776,7 +776,7 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([ } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. 
-func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD new file mode 100644 index 000000000..0d411606f --- /dev/null +++ b/pkg/sentry/fsimpl/pipefs/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "pipefs", + srcs = ["pipefs.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/pipe", + "//pkg/sentry/kernel/time", + "//pkg/sentry/vfs", + "//pkg/syserror", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go new file mode 100644 index 000000000..faf3179bc --- /dev/null +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -0,0 +1,148 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pipefs provides the filesystem implementation backing +// Kernel.PipeMount. +package pipefs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type filesystemType struct{} + +// Name implements vfs.FilesystemType.Name. +func (filesystemType) Name() string { + return "pipefs" +} + +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + panic("pipefs.filesystemType.GetFilesystem should never be called") +} + +// filesystem implements vfs.FilesystemImpl. +type filesystem struct { + kernfs.Filesystem + + // TODO(gvisor.dev/issue/1193): + // + // - kernfs does not provide a way to implement statfs, from which we + // should indicate PIPEFS_MAGIC. + // + // - kernfs does not provide a way to override names for + // vfs.FilesystemImpl.PrependPath(); pipefs inodes should use synthetic + // name fmt.Sprintf("pipe:[%d]", inode.ino). +} + +// NewFilesystem sets up and returns a new vfs.Filesystem implemented by +// pipefs. +func NewFilesystem(vfsObj *vfs.VirtualFilesystem) *vfs.Filesystem { + fs := &filesystem{} + fs.Init(vfsObj, filesystemType{}) + return fs.VFSFilesystem() +} + +// inode implements kernfs.Inode. 
+type inode struct { + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + kernfs.InodeNoopRefCount + + pipe *pipe.VFSPipe + + ino uint64 + uid auth.KUID + gid auth.KGID + // We use the creation timestamp for all of atime, mtime, and ctime. + ctime ktime.Time +} + +func newInode(ctx context.Context, fs *kernfs.Filesystem) *inode { + creds := auth.CredentialsFromContext(ctx) + return &inode{ + pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + ino: fs.NextIno(), + uid: creds.EffectiveKUID, + gid: creds.EffectiveKGID, + ctime: ktime.NowFromContext(ctx), + } +} + +const pipeMode = 0600 | linux.S_IFIFO + +// CheckPermissions implements kernfs.Inode.CheckPermissions. +func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { + return vfs.GenericCheckPermissions(creds, ats, pipeMode, i.uid, i.gid) +} + +// Mode implements kernfs.Inode.Mode. +func (i *inode) Mode() linux.FileMode { + return pipeMode +} + +// Stat implements kernfs.Inode.Stat. +func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds()) + return linux.Statx{ + Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, + Blksize: usermem.PageSize, + Nlink: 1, + UID: uint32(i.uid), + GID: uint32(i.gid), + Mode: pipeMode, + Ino: i.ino, + Size: 0, + Blocks: 0, + Atime: ts, + Ctime: ts, + Mtime: ts, + // TODO(gvisor.dev/issue/1197): Device number. + }, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + if opts.Stat.Mask == 0 { + return nil + } + return syserror.EPERM +} + +// Open implements kernfs.Inode.Open. +func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // FIXME(b/38173783): kernfs does not plumb Context here. + return i.pipe.Open(context.Background(), rp.Mount(), vfsd, opts.Flags) +} + +// NewConnectedPipeFDs returns a pair of FileDescriptions representing the read +// and write ends of a newly-created pipe, as for pipe(2) and pipe2(2). +// +// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). +func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) { + fs := mnt.Filesystem().Impl().(*kernfs.Filesystem) + inode := newInode(ctx, fs) + var d kernfs.Dentry + d.Init(inode) + defer d.DecRef() + return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags) +} diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index aee2a4392..888afc0fd 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -214,22 +214,6 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData { return &ioData{ioUsage: t} } -func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry { - // Namespace symlinks should contain the namespace name and the inode number - // for the namespace instance, so for example user:[123456]. We currently fake - // the inode number by sticking the symlink inode in its place. - target := fmt.Sprintf("%s:[%d]", ns, ino) - - inode := &kernfs.StaticSymlink{} - // Note: credentials are overridden by taskOwnedInode. 
- inode.Init(task.Credentials(), ino, target) - - taskInode := &taskOwnedInode{Inode: inode, owner: task} - d := &kernfs.Dentry{} - d.Init(taskInode) - return d -} - // newCgroupData creates inode that shows cgroup information. // From man 7 cgroups: "For each cgroup hierarchy of which the process is a // member, there is one entry containing three colon-separated fields: diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 9c8656b28..046265eca 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -30,34 +30,35 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -type fdDir struct { - inoGen InoGenerator - task *kernel.Task - - // When produceSymlinks is set, dirents produces for the FDs are reported - // as symlink. Otherwise, they are reported as regular files. - produceSymlink bool -} - -func (i *fdDir) lookup(name string) (*vfs.FileDescription, kernel.FDFlags, error) { - fd, err := strconv.ParseUint(name, 10, 64) - if err != nil { - return nil, kernel.FDFlags{}, syserror.ENOENT - } - +func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) { var ( file *vfs.FileDescription flags kernel.FDFlags ) - i.task.WithMuLocked(func(t *kernel.Task) { - if fdTable := t.FDTable(); fdTable != nil { - file, flags = fdTable.GetVFS2(int32(fd)) + t.WithMuLocked(func(t *kernel.Task) { + if fdt := t.FDTable(); fdt != nil { + file, flags = fdt.GetVFS2(fd) } }) + return file, flags +} + +func taskFDExists(t *kernel.Task, fd int32) bool { + file, _ := getTaskFD(t, fd) if file == nil { - return nil, kernel.FDFlags{}, syserror.ENOENT + return false } - return file, flags, nil + file.DecRef() + return true +} + +type fdDir struct { + inoGen InoGenerator + task *kernel.Task + + // When produceSymlinks is set, dirents produces for the FDs are reported + // as symlink. Otherwise, they are reported as regular files. + produceSymlink bool } // IterDirents implements kernfs.inodeDynamicLookup. @@ -128,11 +129,15 @@ func newFDDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { // Lookup implements kernfs.inodeDynamicLookup. 
func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - file, _, err := i.lookup(name) + fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { - return nil, err + return nil, syserror.ENOENT + } + fd := int32(fdInt) + if !taskFDExists(i.task, fd) { + return nil, syserror.ENOENT } - taskDentry := newFDSymlink(i.task.Credentials(), file, i.inoGen.NextIno()) + taskDentry := newFDSymlink(i.task, fd, i.inoGen.NextIno()) return taskDentry.VFSDentry(), nil } @@ -169,19 +174,22 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia // // +stateify savable type fdSymlink struct { - refs.AtomicRefCount kernfs.InodeAttrs + kernfs.InodeNoopRefCount kernfs.InodeSymlink - file *vfs.FileDescription + task *kernel.Task + fd int32 } var _ kernfs.Inode = (*fdSymlink)(nil) -func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64) *kernfs.Dentry { - file.IncRef() - inode := &fdSymlink{file: file} - inode.Init(creds, ino, linux.ModeSymlink|0777) +func newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry { + inode := &fdSymlink{ + task: task, + fd: fd, + } + inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) @@ -189,29 +197,27 @@ func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64 } func (s *fdSymlink) Readlink(ctx context.Context) (string, error) { + file, _ := getTaskFD(s.task, s.fd) + if file == nil { + return "", syserror.ENOENT + } + defer file.DecRef() root := vfs.RootFromContext(ctx) defer root.DecRef() - - vfsObj := s.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem() - return vfsObj.PathnameWithDeleted(ctx, root, s.file.VirtualDentry()) + return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry()) } func (s *fdSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) { - vd := s.file.VirtualDentry() + file, _ := getTaskFD(s.task, s.fd) + if file == nil { + return vfs.VirtualDentry{}, "", syserror.ENOENT + } + defer file.DecRef() + vd := file.VirtualDentry() vd.IncRef() return vd, "", nil } -func (s *fdSymlink) DecRef() { - s.AtomicRefCount.DecRefWithDestructor(func() { - s.Destroy() - }) -} - -func (s *fdSymlink) Destroy() { - s.file.DecRef() -} - // fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory. // // +stateify savable @@ -244,12 +250,18 @@ func newFDInfoDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { // Lookup implements kernfs.inodeDynamicLookup. func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - file, flags, err := i.lookup(name) + fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { - return nil, err + return nil, syserror.ENOENT + } + fd := int32(fdInt) + if !taskFDExists(i.task, fd) { + return nil, syserror.ENOENT + } + data := &fdInfoData{ + task: i.task, + fd: fd, } - - data := &fdInfoData{file: file, flags: flags} dentry := newTaskOwnedFile(i.task, i.inoGen.NextIno(), 0444, data) return dentry.VFSDentry(), nil } @@ -268,26 +280,23 @@ type fdInfoData struct { kernfs.DynamicBytesFile refs.AtomicRefCount - file *vfs.FileDescription - flags kernel.FDFlags + task *kernel.Task + fd int32 } var _ dynamicInode = (*fdInfoData)(nil) -func (d *fdInfoData) DecRef() { - d.AtomicRefCount.DecRefWithDestructor(d.destroy) -} - -func (d *fdInfoData) destroy() { - d.file.DecRef() -} - // Generate implements vfs.DynamicBytesSource.Generate. 
func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { + file, descriptorFlags := getTaskFD(d.task, d.fd) + if file == nil { + return syserror.ENOENT + } + defer file.DecRef() // TODO(b/121266871): Include pos, locks, and other data. For now we only // have flags. // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt - flags := uint(d.file.StatusFlags()) | d.flags.ToLinuxFileFlags() + flags := uint(file.StatusFlags()) | descriptorFlags.ToLinuxFileFlags() fmt.Fprintf(buf, "flags:\t0%o\n", flags) return nil } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 88ea6a6d8..2c6f8bdfc 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -64,6 +64,16 @@ func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) { return m, nil } +func checkTaskState(t *kernel.Task) error { + switch t.ExitState() { + case kernel.TaskExitZombie: + return syserror.EACCES + case kernel.TaskExitDead: + return syserror.ESRCH + } + return nil +} + type bufferWriter struct { buf *bytes.Buffer } @@ -628,11 +638,13 @@ func (s *exeSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, er } func (s *exeSymlink) executable() (file fsbridge.File, err error) { + if err := checkTaskState(s.task); err != nil { + return nil, err + } + s.task.WithMuLocked(func(t *kernel.Task) { mm := t.MemoryManager() if mm == nil { - // TODO(b/34851096): Check shouldn't allow Readlink once the - // Task is zombied. err = syserror.EACCES return } @@ -642,7 +654,7 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) { // (with locks held). file = mm.Executable() if file == nil { - err = syserror.ENOENT + err = syserror.ESRCH } }) return @@ -709,3 +721,41 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error { i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf) return nil } + +type namespaceSymlink struct { + kernfs.StaticSymlink + + task *kernel.Task +} + +func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry { + // Namespace symlinks should contain the namespace name and the inode number + // for the namespace instance, so for example user:[123456]. We currently fake + // the inode number by sticking the symlink inode in its place. + target := fmt.Sprintf("%s:[%d]", ns, ino) + + inode := &namespaceSymlink{task: task} + // Note: credentials are overridden by taskOwnedInode. + inode.Init(task.Credentials(), ino, target) + + taskInode := &taskOwnedInode{Inode: inode, owner: task} + d := &kernfs.Dentry{} + d.Init(taskInode) + return d +} + +// Readlink implements Inode. +func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) { + if err := checkTaskState(s.task); err != nil { + return "", err + } + return s.StaticSymlink.Readlink(ctx) +} + +// Getlink implements Inode.Getlink. +func (s *namespaceSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) { + if err := checkTaskState(s.task); err != nil { + return vfs.VirtualDentry{}, "", err + } + return s.StaticSymlink.Getlink(ctx) +} diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 6b2a77328..6595fcee6 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -688,9 +688,9 @@ func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { if line.prefix == "Tcp" { tcp := stat.(*inet.StatSNMPTCP) // "Tcp" needs special processing because MaxConn is signed. RFC 2012. 
- fmt.Sprintf("%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) + fmt.Fprintf(buf, "%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) } else { - fmt.Sprintf("%s: %s\n", line.prefix, sprintSlice(toSlice(stat))) + fmt.Fprintf(buf, "%s: %s\n", line.prefix, sprintSlice(toSlice(stat))) } } return nil diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index f2ac23c88..4e6cd3491 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/vfs", "//pkg/sentry/vfs/lock", + "//pkg/sentry/vfs/memxattr", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 5339d7072..660f5a29b 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -392,7 +392,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open // Can't open symlinks without O_PATH (which is unimplemented). return nil, syserror.ELOOP case *namedPipe: - return newNamedPipeFD(ctx, impl, rp, &d.vfsd, opts.Flags) + return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags) case *deviceFile: return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts) case *socketFile: @@ -696,51 +696,47 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return nil, err } - // TODO(b/127675828): support extended attributes - return nil, syserror.ENOTSUP + return d.inode.listxattr(size) } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return "", err } - // TODO(b/127675828): support extended attributes - return "", syserror.ENOTSUP + return d.inode.getxattr(rp.Credentials(), &opts) } // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return err } - // TODO(b/127675828): support extended attributes - return syserror.ENOTSUP + return d.inode.setxattr(rp.Credentials(), &opts) } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return err } - // TODO(b/127675828): support extended attributes - return syserror.ENOTSUP + return d.inode.removexattr(rp.Credentials(), name) } // PrependPath implements vfs.FilesystemImpl.PrependPath. 
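The four *xattrAt hooks above all follow the same shape: resolve the dentry under fs.mu, then delegate to the inode, which gates the user.* namespace and the file type before touching the in-memory memxattr store. A minimal sketch of the resulting inode-level behavior, assuming the option structs carry Name/Value/Size fields as this hunk suggests:

// Round-trips a user.* attribute through a tmpfs inode. Names outside user.*
// are rejected with EOPNOTSUPP, and unsupported file types with ENODATA/EPERM,
// before the memxattr store is consulted.
func xattrRoundTripSketch(i *inode, creds *auth.Credentials) error {
	if err := i.setxattr(creds, &vfs.SetxattrOptions{
		Name:  "user.example",
		Value: "hello",
	}); err != nil {
		return err
	}
	// Size == 0 is taken here to mean "no caller buffer limit" (assumption
	// based on memxattr's handling of GetxattrOptions.Size).
	got, err := i.getxattr(creds, &vfs.GetxattrOptions{Name: "user.example", Size: 0})
	if err != nil {
		return err
	}
	if got != "hello" {
		return fmt.Errorf("unexpected xattr value %q", got)
	}
	return i.removexattr(creds, "user.example")
}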
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index 2c5c739df..8d77b3fa8 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -16,10 +16,8 @@ package tmpfs import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/usermem" ) @@ -33,27 +31,8 @@ type namedPipe struct { // * fs.mu must be locked. // * rp.Mount().CheckBeginWrite() has been called successfully. func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode { - file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)} + file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} file.inode.init(file, fs, creds, linux.S_IFIFO|mode) file.inode.nlink = 1 // Only the parent has a link. return &file.inode } - -// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented -// entirely via struct embedding. -type namedPipeFD struct { - fileDescription - - *pipe.VFSPipeFD -} - -func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { - var err error - var fd namedPipeFD - fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, vfsd, &fd.vfsfd, flags) - if err != nil { - return nil, err - } - fd.vfsfd.Init(&fd, flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}) - return &fd.vfsfd, nil -} diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go index 3e02e7190..d4f59ee5b 100644 --- a/pkg/sentry/fsimpl/tmpfs/stat_test.go +++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go @@ -140,7 +140,7 @@ func TestSetStatAtime(t *testing.T) { Mask: 0, Atime: linux.NsecToStatxTimestamp(100), }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") + t.Errorf("SetStat atime without mask failed: %v", err) } // Atime should be unchanged. if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { @@ -155,7 +155,7 @@ func TestSetStatAtime(t *testing.T) { Atime: linux.NsecToStatxTimestamp(100), } if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") + t.Errorf("SetStat atime with mask failed: %v", err) } if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { t.Errorf("Stat got error: %v", err) @@ -205,7 +205,7 @@ func TestSetStat(t *testing.T) { Mask: 0, Atime: linux.NsecToStatxTimestamp(100), }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") + t.Errorf("SetStat atime without mask failed: %v", err) } // Atime should be unchanged. 
if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { @@ -220,7 +220,7 @@ func TestSetStat(t *testing.T) { Atime: linux.NsecToStatxTimestamp(100), } if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") + t.Errorf("SetStat atime with mask failed: %v", err) } if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { t.Errorf("Stat got error: %v", err) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 654e788e3..a59b24d45 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -27,6 +27,7 @@ package tmpfs import ( "fmt" "math" + "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -37,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/vfs/lock" + "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -186,6 +188,11 @@ type inode struct { // filesystem.RmdirAt() drops the reference. refs int64 + // xattrs implements extended attributes. + // + // TODO(b/148380782): Support xattrs other than user.* + xattrs memxattr.SimpleExtendedAttributes + // Inode metadata. Writing multiple fields atomically requires holding // mu, othewise atomic operations can be used. mu sync.Mutex @@ -350,6 +357,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return err } i.mu.Lock() + defer i.mu.Unlock() var ( needsMtimeBump bool needsCtimeBump bool @@ -420,7 +428,6 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu atomic.StoreInt64(&i.ctime, now) } - i.mu.Unlock() return nil } @@ -535,6 +542,56 @@ func (i *inode) touchCMtimeLocked() { atomic.StoreInt64(&i.ctime, now) } +func (i *inode) listxattr(size uint64) ([]string, error) { + return i.xattrs.Listxattr(size) +} + +func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { + if err := i.checkPermissions(creds, vfs.MayRead); err != nil { + return "", err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return "", syserror.EOPNOTSUPP + } + if !i.userXattrSupported() { + return "", syserror.ENODATA + } + return i.xattrs.Getxattr(opts) +} + +func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error { + if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + if !i.userXattrSupported() { + return syserror.EPERM + } + return i.xattrs.Setxattr(opts) +} + +func (i *inode) removexattr(creds *auth.Credentials, name string) error { + if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + if !i.userXattrSupported() { + return syserror.EPERM + } + return i.xattrs.Removexattr(name) +} + +// Extended attributes in the user.* namespace are only supported for regular +// files and directories. +func (i *inode) userXattrSupported() bool { + filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode) + return filetype == linux.S_IFREG || filetype == linux.S_IFDIR +} + // fileDescription is embedded by tmpfs implementations of // vfs.FileDescriptionImpl. 
type fileDescription struct { @@ -562,3 +619,23 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) creds := auth.CredentialsFromContext(ctx) return fd.inode().setStat(ctx, creds, &opts.Stat) } + +// Listxattr implements vfs.FileDescriptionImpl.Listxattr. +func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { + return fd.inode().listxattr(size) +} + +// Getxattr implements vfs.FileDescriptionImpl.Getxattr. +func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { + return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts) +} + +// Setxattr implements vfs.FileDescriptionImpl.Setxattr. +func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { + return fd.inode().setxattr(auth.CredentialsFromContext(ctx), &opts) +} + +// Removexattr implements vfs.FileDescriptionImpl.Removexattr. +func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { + return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name) +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e0ff58d8c..e47af66d6 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -170,6 +170,7 @@ go_library( "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostcpu", "//pkg/sentry/inet", diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index d09d97825..ed40b5303 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -307,6 +307,61 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return fds, nil } +// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available +// greater than or equal to the fd parameter. All files will share the set +// flags. Success is guaranteed to be all or none. +func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) { + if fd < 0 { + // Don't accept negative FDs. + return nil, syscall.EINVAL + } + + // Default limit. + end := int32(math.MaxInt32) + + // Ensure we don't get past the provided limit. + if limitSet := limits.FromContext(ctx); limitSet != nil { + lim := limitSet.Get(limits.NumberOfFiles) + if lim.Cur != limits.Infinity { + end = int32(lim.Cur) + } + if fd >= end { + return nil, syscall.EMFILE + } + } + + f.mu.Lock() + defer f.mu.Unlock() + + // From f.next to find available fd. + if fd < f.next { + fd = f.next + } + + // Install all entries. + for i := fd; i < end && len(fds) < len(files); i++ { + if d, _, _ := f.getVFS2(i); d == nil { + f.setVFS2(i, files[len(fds)], flags) // Set the descriptor. + fds = append(fds, i) // Record the file descriptor. + } + } + + // Failure? Unwind existing FDs. + if len(fds) < len(files) { + for _, i := range fds { + f.setVFS2(i, nil, FDFlags{}) // Zap entry. + } + return nil, syscall.EMFILE + } + + if fd == f.next { + // Update next search start position. + f.next = fds[len(fds)-1] + 1 + } + + return fds, nil +} + // NewFDVFS2 allocates a file descriptor greater than or equal to minfd for // the given file description. If it succeeds, it takes a reference on file. 
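// Editor's sketch (hypothetical caller, not in this patch): NewFDsVFS2 exists
// so the VFS2 pipe(2)/pipe2(2) path can install both ends of a pipe with
// all-or-nothing semantics, mirroring how the VFS1 path uses NewFDs. It is
// reached through kernel.Task.NewFDsVFS2, added later in this change; the
// CLOEXEC translation below is an assumption about that caller.
func installPipeFDsSketch(t *kernel.Task, r, w *vfs.FileDescription, flags int32) ([]int32, error) {
	return t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{
		CloseOnExec: flags&linux.O_CLOEXEC != 0,
	})
}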
func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index de8a95854..fef60e636 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -50,6 +50,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/timerfd" "gvisor.dev/gvisor/pkg/sentry/fsbridge" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/hostcpu" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -254,6 +255,10 @@ type Kernel struct { // VFS keeps the filesystem state used across the kernel. vfs vfs.VirtualFilesystem + // pipeMount is the Mount used for pipes created by the pipe() and pipe2() + // syscalls (as opposed to named pipes created by mknod()). + pipeMount *vfs.Mount + // If set to true, report address space activation waits as if the task is in // external wait so that the watchdog doesn't report the task stuck. SleepForAddressSpaceActivation bool @@ -354,19 +359,29 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() + if VFS2Enabled { if err := k.vfs.Init(); err != nil { return fmt.Errorf("failed to initialize VFS: %v", err) } - fs := sockfs.NewFilesystem(&k.vfs) - // NewDisconnectedMount will take an additional reference on fs. - defer fs.DecRef() - sm, err := k.vfs.NewDisconnectedMount(fs, nil, &vfs.MountOptions{}) + + pipeFilesystem := pipefs.NewFilesystem(&k.vfs) + defer pipeFilesystem.DecRef() + pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{}) + if err != nil { + return fmt.Errorf("failed to create pipefs mount: %v", err) + } + k.pipeMount = pipeMount + + socketFilesystem := sockfs.NewFilesystem(&k.vfs) + defer socketFilesystem.DecRef() + socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{}) if err != nil { return fmt.Errorf("failed to initialize socket mount: %v", err) } - k.socketMount = sm + k.socketMount = socketMount } + return nil } @@ -1613,3 +1628,8 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { func (k *Kernel) VFS() *vfs.VirtualFilesystem { return &k.vfs } + +// PipeMount returns the pipefs mount. +func (k *Kernel) PipeMount() *vfs.Mount { + return k.pipeMount +} diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index a5675bd70..b54f08a30 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -49,38 +49,42 @@ type VFSPipe struct { } // NewVFSPipe returns an initialized VFSPipe. -func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe { +func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { var vp VFSPipe - initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes) + initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes) return &vp } -// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics -// during open: +// ReaderWriterPair returns read-only and write-only FDs for vp. // -// "Normally, opening the FIFO blocks until the other end is opened also. A -// process can open a FIFO in nonblocking mode. In this case, opening for -// read-only will succeed even if no-one has opened on the write side yet, -// opening for write-only will fail with ENXIO (no such device or address) -// unless the other end has already been opened. 
Under Linux, opening a FIFO -// for read and write will succeed both in blocking and nonblocking mode. POSIX -// leaves this behavior undefined. This can be used to open a FIFO for writing -// while there are no readers available." - fifo(7) -func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { +// Preconditions: statusFlags should not contain an open access mode. +func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) { + return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags) +} + +// Open opens the pipe represented by vp. +func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, error) { vp.mu.Lock() defer vp.mu.Unlock() - readable := vfs.MayReadFileWithOpenFlags(flags) - writable := vfs.MayWriteFileWithOpenFlags(flags) + readable := vfs.MayReadFileWithOpenFlags(statusFlags) + writable := vfs.MayWriteFileWithOpenFlags(statusFlags) if !readable && !writable { return nil, syserror.EINVAL } - vfd, err := vp.open(vfsd, vfsfd, flags) - if err != nil { - return nil, err - } + fd := vp.newFD(mnt, vfsd, statusFlags) + // Named pipes have special blocking semantics during open: + // + // "Normally, opening the FIFO blocks until the other end is opened also. A + // process can open a FIFO in nonblocking mode. In this case, opening for + // read-only will succeed even if no-one has opened on the write side yet, + // opening for write-only will fail with ENXIO (no such device or address) + // unless the other end has already been opened. Under Linux, opening a + // FIFO for read and write will succeed both in blocking and nonblocking + // mode. POSIX leaves this behavior undefined. This can be used to open a + // FIFO for writing while there are no readers available." - fifo(7) switch { case readable && writable: // Pipes opened for read-write always succeed without blocking. @@ -89,23 +93,26 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vf case readable: newHandleLocked(&vp.rWakeup) - // If this pipe is being opened as nonblocking and there's no + // If this pipe is being opened as blocking and there's no // writer, we have to wait for a writer to open the other end. - if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { + if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { + fd.DecRef() return nil, syserror.EINTR } case writable: newHandleLocked(&vp.wWakeup) - if !vp.pipe.HasReaders() { - // Nonblocking, write-only opens fail with ENXIO when - // the read side isn't open yet. - if flags&linux.O_NONBLOCK != 0 { + if vp.pipe.isNamed && !vp.pipe.HasReaders() { + // Non-blocking, write-only opens fail with ENXIO when the read + // side isn't open yet. + if statusFlags&linux.O_NONBLOCK != 0 { + fd.DecRef() return nil, syserror.ENXIO } // Wait for a reader to open the other end. if !waitFor(&vp.mu, &vp.rWakeup, ctx) { + fd.DecRef() return nil, syserror.EINTR } } @@ -114,96 +121,93 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, vfsd *vfs.Dentry, vfsfd *vf panic("invalid pipe flags: must be readable, writable, or both") } - return vfd, nil + return fd, nil } // Preconditions: vp.mu must be held. 
-func (vp *VFSPipe) open(vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { - var fd VFSPipeFD - fd.flags = flags - fd.readable = vfs.MayReadFileWithOpenFlags(flags) - fd.writable = vfs.MayWriteFileWithOpenFlags(flags) - fd.vfsfd = vfsfd - fd.pipe = &vp.pipe +func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *vfs.FileDescription { + fd := &VFSPipeFD{ + pipe: &vp.pipe, + } + fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }) switch { - case fd.readable && fd.writable: + case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable(): vp.pipe.rOpen() vp.pipe.wOpen() - case fd.readable: + case fd.vfsfd.IsReadable(): vp.pipe.rOpen() - case fd.writable: + case fd.vfsfd.IsWritable(): vp.pipe.wOpen() default: panic("invalid pipe flags: must be readable, writable, or both") } - return &fd, nil + return &fd.vfsfd } -// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is -// expected that filesystesm will use this in a struct implementing -// vfs.FileDescriptionImpl. +// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. type VFSPipeFD struct { - pipe *Pipe - flags uint32 - readable bool - writable bool - vfsfd *vfs.FileDescription + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + + pipe *Pipe } // Release implements vfs.FileDescriptionImpl.Release. func (fd *VFSPipeFD) Release() { var event waiter.EventMask - if fd.readable { + if fd.vfsfd.IsReadable() { fd.pipe.rClose() - event |= waiter.EventIn + event |= waiter.EventOut } - if fd.writable { + if fd.vfsfd.IsWritable() { fd.pipe.wClose() - event |= waiter.EventOut + event |= waiter.EventIn | waiter.EventHUp } if event == 0 { panic("invalid pipe flags: must be readable, writable, or both") } - if fd.writable { - fd.vfsfd.VirtualDentry().Mount().EndWrite() - } - fd.pipe.Notify(event) } -// OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *VFSPipeFD) OnClose(_ context.Context) error { - return nil +// Readiness implements waiter.Waitable.Readiness. +func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask { + switch { + case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable(): + return fd.pipe.rwReadiness() + case fd.vfsfd.IsReadable(): + return fd.pipe.rReadiness() + case fd.vfsfd.IsWritable(): + return fd.pipe.wReadiness() + default: + panic("pipe FD is neither readable nor writable") + } } -// PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) { - return 0, syserror.ESPIPE +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.pipe.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *VFSPipeFD) EventUnregister(e *waiter.Entry) { + fd.pipe.EventUnregister(e) } // Read implements vfs.FileDescriptionImpl.Read. func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { - if !fd.readable { - return 0, syserror.EINVAL - } - return fd.pipe.Read(ctx, dst) } -// PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) { - return 0, syserror.ESPIPE -} - // Write implements vfs.FileDescriptionImpl.Write. 
func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { - if !fd.writable { - return 0, syserror.EINVAL - } - return fd.pipe.Write(ctx, src) } @@ -211,3 +215,17 @@ func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.Wr func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { return fd.pipe.Ioctl(ctx, uio, args) } + +// PipeSize implements fcntl(F_GETPIPE_SZ). +func (fd *VFSPipeFD) PipeSize() int64 { + // Inline Pipe.FifoSize() rather than calling it with nil Context and + // fs.File and ignoring the returned error (which is always nil). + fd.pipe.mu.Lock() + defer fd.pipe.mu.Unlock() + return fd.pipe.max +} + +// SetPipeSize implements fcntl(F_SETPIPE_SZ). +func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { + return fd.pipe.SetFifoSize(size) +} diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 35ad97d5d..e23e796ef 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -184,7 +184,6 @@ func (t *Task) CanTrace(target *Task, attach bool) bool { if targetCreds.PermittedCaps&^callerCreds.PermittedCaps != 0 { return false } - // TODO: Yama LSM return true } diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 208569057..f66cfcc7f 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -461,7 +461,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.A func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) { s.mu.Lock() defer s.mu.Unlock() - // TODO(b/38173783): RemoveMapping may be called during task exit, when ctx + // RemoveMapping may be called during task exit, when ctx // is context.Background. Gracefully handle missing clocks. Failing to // update the detach time in these cases is ok, since no one can observe the // omission. diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 93c4fe969..84156d5a1 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -209,65 +209,61 @@ type Stracer interface { // SyscallEnter is called on syscall entry. // // The returned private data is passed to SyscallExit. - // - // TODO(gvisor.dev/issue/155): remove kernel imports from the strace - // package so that the type can be used directly. SyscallEnter(t *Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{} // SyscallExit is called on syscall exit. SyscallExit(context interface{}, t *Task, sysno, rval uintptr, err error) } -// SyscallTable is a lookup table of system calls. Critically, a SyscallTable -// is *immutable*. In order to make supporting suspend and resume sane, they -// must be uniquely registered and may not change during operation. +// SyscallTable is a lookup table of system calls. // -// +stateify savable +// Note that a SyscallTable is not savable directly. Instead, they are saved as +// an OS/Arch pair and lookup happens again on restore. type SyscallTable struct { // OS is the operating system that this syscall table implements. - OS abi.OS `state:"wait"` + OS abi.OS // Arch is the architecture that this syscall table targets. - Arch arch.Arch `state:"wait"` + Arch arch.Arch // The OS version that this syscall table implements. - Version Version `state:"manual"` + Version Version // AuditNumber is a numeric constant that represents the syscall table. 
If // non-zero, auditNumber must be one of the AUDIT_ARCH_* values defined by // linux/audit.h. - AuditNumber uint32 `state:"manual"` + AuditNumber uint32 // Table is the collection of functions. - Table map[uintptr]Syscall `state:"manual"` + Table map[uintptr]Syscall // lookup is a fixed-size array that holds the syscalls (indexed by // their numbers). It is used for fast look ups. - lookup []SyscallFn `state:"manual"` + lookup []SyscallFn // Emulate is a collection of instruction addresses to emulate. The // keys are addresses, and the values are system call numbers. - Emulate map[usermem.Addr]uintptr `state:"manual"` + Emulate map[usermem.Addr]uintptr // The function to call in case of a missing system call. - Missing MissingFn `state:"manual"` + Missing MissingFn // Stracer traces this syscall table. - Stracer Stracer `state:"manual"` + Stracer Stracer // External is used to handle an external callback. - External func(*Kernel) `state:"manual"` + External func(*Kernel) // ExternalFilterBefore is called before External is called before the syscall is executed. // External is not called if it returns false. - ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"` + ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool // ExternalFilterAfter is called before External is called after the syscall is executed. // External is not called if it returns false. - ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"` + ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool // FeatureEnable stores the strace and one-shot enable bits. - FeatureEnable SyscallFlagsTable `state:"manual"` + FeatureEnable SyscallFlagsTable } // allSyscallTables contains all known tables. @@ -330,6 +326,13 @@ func RegisterSyscallTable(s *SyscallTable) { allSyscallTables = append(allSyscallTables, s) } +// FlushSyscallTablesTestOnly flushes the syscall tables for tests. Used for +// parameterized VFSv2 tests. +// TODO(gvisor.dv/issue/1624): Remove when VFS1 is no longer supported. +func FlushSyscallTablesTestOnly() { + allSyscallTables = nil +} + // Lookup returns the syscall implementation, if one exists. func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn { if sysno < uintptr(len(s.lookup)) { diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go index 00358326b..90f890495 100644 --- a/pkg/sentry/kernel/syscalls_state.go +++ b/pkg/sentry/kernel/syscalls_state.go @@ -14,16 +14,34 @@ package kernel -import "fmt" +import ( + "fmt" -// afterLoad is invoked by stateify. -func (s *SyscallTable) afterLoad() { - otherTable, ok := LookupSyscallTable(s.OS, s.Arch) - if !ok { - // Couldn't find a reference? - panic(fmt.Sprintf("syscall table not found for OS %v Arch %v", s.OS, s.Arch)) + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/sentry/arch" +) + +// syscallTableInfo is used to reload the SyscallTable. +// +// +stateify savable +type syscallTableInfo struct { + OS abi.OS + Arch arch.Arch +} + +// saveSt saves the SyscallTable. +func (tc *TaskContext) saveSt() syscallTableInfo { + return syscallTableInfo{ + OS: tc.st.OS, + Arch: tc.st.Arch, } +} - // Copy the table. - *s = *otherTable +// loadSt loads the SyscallTable. +func (tc *TaskContext) loadSt(sti syscallTableInfo) { + st, ok := LookupSyscallTable(sti.OS, sti.Arch) + if !ok { + panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch)) + } + tc.st = st // Save the table reference. 
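	// Editor's note (sketch of the round trip this enables, same names as
	// above): at save time only the (OS, Arch) pair is serialized via saveSt,
	// and on restore loadSt re-resolves the freshly registered table, so none
	// of SyscallTable's function pointers ever need to be savable:
	//
	//	sti := tc.saveSt()   // checkpoint: records syscallTableInfo{OS, Arch}
	//	// ... save / restore ...
	//	tc.loadSt(sti)       // restore: LookupSyscallTable(sti.OS, sti.Arch)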
} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index d6546735e..e5d133d6c 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -777,6 +777,15 @@ func (t *Task) NewFDs(fd int32, files []*fs.File, flags FDFlags) ([]int32, error return t.fdTable.NewFDs(t, fd, files, flags) } +// NewFDsVFS2 is a convenience wrapper for t.FDTable().NewFDsVFS2. +// +// This automatically passes the task as the context. +// +// Precondition: same as FDTable. +func (t *Task) NewFDsVFS2(fd int32, files []*vfs.FileDescription, flags FDFlags) ([]int32, error) { + return t.fdTable.NewFDsVFS2(t, fd, files, flags) +} + // NewFDFrom is a convenience wrapper for t.FDTable().NewFDs with a single file. // // This automatically passes the task as the context. diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 0158b1788..9fa528384 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -49,7 +49,7 @@ type TaskContext struct { fu *futex.Manager // st is the task's syscall table. - st *SyscallTable + st *SyscallTable `state:".(syscallTableInfo)"` } // release releases all resources held by the TaskContext. release is called by @@ -58,7 +58,6 @@ func (tc *TaskContext) release() { // Nil out pointers so that if the task is saved after release, it doesn't // follow the pointers to possibly now-invalid objects. if tc.MemoryManager != nil { - // TODO(b/38173783) tc.MemoryManager.DecUsers(context.Background()) tc.MemoryManager = nil } diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index ce3e6ef28..0325967e4 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -455,7 +455,7 @@ func (t *Task) SetKeepCaps(k bool) { t.creds.Store(creds) } -// updateCredsForExec updates t.creds to reflect an execve(). +// updateCredsForExecLocked updates t.creds to reflect an execve(). // // NOTE(b/30815691): We currently do not implement privileged executables // (set-user/group-ID bits and file capabilities). This allows us to make a lot diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 799cbcd93..2ba8d7e63 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -353,7 +353,7 @@ func (app *runApp) execute(t *Task) taskRunState { default: // What happened? Can't continue. 
t.Warningf("Unexpected SwitchToApp error: %v", err) - t.PrepareExit(ExitStatus{Code: t.ExtractErrno(err, -1)}) + t.PrepareExit(ExitStatus{Code: ExtractErrno(err, -1)}) return (*runExit)(nil) } } diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 8802db142..f07de2089 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -174,7 +174,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS fallthrough case (sre == ERESTARTSYS && !act.IsRestart()): t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) - t.Arch().SetReturn(uintptr(-t.ExtractErrno(syserror.EINTR, -1))) + t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1))) default: t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) t.Arch().RestartSyscall() @@ -513,8 +513,6 @@ func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool { if t.stop != nil { return false } - // - TODO(b/38173783): No special case for when t is also the sending task, - // because the identity of the sender is unknown. // - Do not choose tasks that have already been interrupted, as they may be // busy handling another signal. if len(t.interruptChan) != 0 { diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index d555d69a8..c9db78e06 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -194,6 +194,19 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u // // The syscall path is very hot; avoid defer. func (t *Task) doSyscall() taskRunState { + // Save value of the register which is clobbered in the following + // t.Arch().SetReturn(-ENOSYS) operation. This is dedicated to arm64. + // + // On x86, register rax was shared by syscall number and return + // value, and at the entry of the syscall handler, the rax was + // saved to regs.orig_rax which was exposed to user space. + // But on arm64, syscall number was passed through X8, and the X0 + // was shared by the first syscall argument and return value. The + // X0 was saved to regs.orig_x0 which was not exposed to user space. + // So we have to do the same operation here to save the X0 value + // into the task context. + t.Arch().SyscallSaveOrig() + sysno := t.Arch().SyscallNo() args := t.Arch().SyscallArgs() @@ -269,6 +282,7 @@ func (*runSyscallAfterSyscallEnterStop) execute(t *Task) taskRunState { return (*runSyscallExit)(nil) } args := t.Arch().SyscallArgs() + return t.doSyscallInvoke(sysno, args) } @@ -298,7 +312,7 @@ func (t *Task) doSyscallInvoke(sysno uintptr, args arch.SyscallArguments) taskRu return ctrl.next } } else if err != nil { - t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno)))) + t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno)))) t.haveSyscallReturn = true } else { t.Arch().SetReturn(rval) @@ -417,7 +431,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle // A return is not emulated in this case. return (*runApp)(nil) } - t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno)))) + t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno)))) } t.Arch().SetIP(t.Arch().Value(caller)) t.Arch().SetStack(t.Arch().Stack() + uintptr(t.Arch().Width())) @@ -427,7 +441,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle // ExtractErrno extracts an integer error number from the error. 
// The syscall number is purely for context in the error case. Use -1 if // syscall number is unknown. -func (t *Task) ExtractErrno(err error, sysno int) int { +func ExtractErrno(err error, sysno int) int { switch err := err.(type) { case nil: return 0 @@ -441,11 +455,11 @@ func (t *Task) ExtractErrno(err error, sysno int) int { // handled (and the SIGBUS is delivered). return int(syscall.EFAULT) case *os.PathError: - return t.ExtractErrno(err.Err, sysno) + return ExtractErrno(err.Err, sysno) case *os.LinkError: - return t.ExtractErrno(err.Err, sysno) + return ExtractErrno(err.Err, sysno) case *os.SyscallError: - return t.ExtractErrno(err.Err, sysno) + return ExtractErrno(err.Err, sysno) default: if errno, ok := syserror.TranslateError(err); ok { return int(errno) diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index 706de83ef..e959700f2 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -245,7 +245,7 @@ type Clock interface { type WallRateClock struct{} // WallTimeUntil implements Clock.WallTimeUntil. -func (WallRateClock) WallTimeUntil(t, now Time) time.Duration { +func (*WallRateClock) WallTimeUntil(t, now Time) time.Duration { return t.Sub(now) } @@ -254,16 +254,16 @@ func (WallRateClock) WallTimeUntil(t, now Time) time.Duration { type NoClockEvents struct{} // Readiness implements waiter.Waitable.Readiness. -func (NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask { +func (*NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask { return 0 } // EventRegister implements waiter.Waitable.EventRegister. -func (NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (*NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) { } // EventUnregister implements waiter.Waitable.EventUnregister. -func (NoClockEvents) EventUnregister(e *waiter.Entry) { +func (*NoClockEvents) EventUnregister(e *waiter.Entry) { } // ClockEventsQueue implements waiter.Waitable by wrapping waiter.Queue and @@ -273,7 +273,7 @@ type ClockEventsQueue struct { } // Readiness implements waiter.Waitable.Readiness. -func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask { +func (*ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask { return 0 } diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index 0332fc71c..5c667117c 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -201,8 +201,10 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, pre if pma.needCOW { perms.Write = false } - if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil { - return err + if perms.Any() { // MapFile precondition + if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil { + return err + } } pseg = pseg.NextSegment() } diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index cb29d94b0..379148903 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -59,25 +59,27 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool { } a.contexts[id] = &AIOContext{ - done: make(chan struct{}, 1), + requestReady: make(chan struct{}, 1), maxOutstanding: events, } return true } -// destroyAIOContext destroys an asynchronous I/O context. +// destroyAIOContext destroys an asynchronous I/O context. It doesn't wait for +// for pending requests to complete. 
Returns the destroyed AIOContext so it can +// be drained. // -// False is returned if the context does not exist. -func (a *aioManager) destroyAIOContext(id uint64) bool { +// Nil is returned if the context does not exist. +func (a *aioManager) destroyAIOContext(id uint64) *AIOContext { a.mu.Lock() defer a.mu.Unlock() ctx, ok := a.contexts[id] if !ok { - return false + return nil } delete(a.contexts, id) ctx.destroy() - return true + return ctx } // lookupAIOContext looks up the given context. @@ -102,8 +104,8 @@ type ioResult struct { // // +stateify savable type AIOContext struct { - // done is the notification channel used for all requests. - done chan struct{} `state:"nosave"` + // requestReady is the notification channel used for all requests. + requestReady chan struct{} `state:"nosave"` // mu protects below. mu sync.Mutex `state:"nosave"` @@ -129,8 +131,14 @@ func (ctx *AIOContext) destroy() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.dead = true - if ctx.outstanding == 0 { - close(ctx.done) + ctx.checkForDone() +} + +// Preconditions: ctx.mu must be held by caller. +func (ctx *AIOContext) checkForDone() { + if ctx.dead && ctx.outstanding == 0 { + close(ctx.requestReady) + ctx.requestReady = nil } } @@ -154,11 +162,12 @@ func (ctx *AIOContext) PopRequest() (interface{}, bool) { // Is there anything ready? if e := ctx.results.Front(); e != nil { - ctx.results.Remove(e) - ctx.outstanding-- - if ctx.outstanding == 0 && ctx.dead { - close(ctx.done) + if ctx.outstanding == 0 { + panic("AIOContext outstanding is going negative") } + ctx.outstanding-- + ctx.results.Remove(e) + ctx.checkForDone() return e.data, true } return nil, false @@ -172,26 +181,58 @@ func (ctx *AIOContext) FinishRequest(data interface{}) { // Push to the list and notify opportunistically. The channel notify // here is guaranteed to be safe because outstanding must be non-zero. - // The done channel is only closed when outstanding reaches zero. + // The requestReady channel is only closed when outstanding reaches zero. ctx.results.PushBack(&ioResult{data: data}) select { - case ctx.done <- struct{}{}: + case ctx.requestReady <- struct{}{}: default: } } // WaitChannel returns a channel that is notified when an AIO request is -// completed. -// -// The boolean return value indicates whether or not the context is active. -func (ctx *AIOContext) WaitChannel() (chan struct{}, bool) { +// completed. Returns nil if the context is destroyed and there are no more +// outstanding requests. +func (ctx *AIOContext) WaitChannel() chan struct{} { ctx.mu.Lock() defer ctx.mu.Unlock() - if ctx.outstanding == 0 && ctx.dead { - return nil, false + return ctx.requestReady +} + +// Dead returns true if the context has been destroyed. +func (ctx *AIOContext) Dead() bool { + ctx.mu.Lock() + defer ctx.mu.Unlock() + return ctx.dead +} + +// CancelPendingRequest forgets about a request that hasn't yet completed. +func (ctx *AIOContext) CancelPendingRequest() { + ctx.mu.Lock() + defer ctx.mu.Unlock() + + if ctx.outstanding == 0 { + panic("AIOContext outstanding is going negative") } - return ctx.done, true + ctx.outstanding-- + ctx.checkForDone() +} + +// Drain drops all completed requests. Pending requests remain untouched. 
+func (ctx *AIOContext) Drain() { + ctx.mu.Lock() + defer ctx.mu.Unlock() + + if ctx.outstanding == 0 { + return + } + size := uint32(ctx.results.Len()) + if ctx.outstanding < size { + panic("AIOContext outstanding is going negative") + } + ctx.outstanding -= size + ctx.results.Reset() + ctx.checkForDone() } // aioMappable implements memmap.MappingIdentity and memmap.Mappable for AIO @@ -332,9 +373,9 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint Length: aioRingBufferSize, MappingIdentity: m, Mappable: m, - // TODO(fvoznika): Linux does "do_mmap_pgoff(..., PROT_READ | - // PROT_WRITE, ...)" in fs/aio.c:aio_setup_ring(); why do we make this - // mapping read-only? + // Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in + // fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC, + // user mode should not write to this page. Perms: usermem.Read, MaxPerms: usermem.Read, }) @@ -349,11 +390,11 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint return id, nil } -// DestroyAIOContext destroys an asynchronous I/O context. It returns false if -// the context does not exist. -func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) bool { +// DestroyAIOContext destroys an asynchronous I/O context. It returns the +// destroyed context. nil if the context does not exist. +func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext { if _, ok := mm.LookupAIOContext(ctx, id); !ok { - return false + return nil } // Only unmaps after it assured that the address is a valid aio context to diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go index c37fc9f7b..3dabac1af 100644 --- a/pkg/sentry/mm/aio_context_state.go +++ b/pkg/sentry/mm/aio_context_state.go @@ -16,5 +16,5 @@ package mm // afterLoad is invoked by stateify. func (a *AIOContext) afterLoad() { - a.done = make(chan struct{}, 1) + a.requestReady = make(chan struct{}, 1) } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index d8a5b9d29..aac56679b 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -84,6 +84,7 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { dumpability: mm.dumpability, aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, sleepForActivation: mm.sleepForActivation, + vdsoSigReturnAddr: mm.vdsoSigReturnAddr, } // Copy vmas. diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go index 6a49334f4..28e5057f7 100644 --- a/pkg/sentry/mm/metadata.go +++ b/pkg/sentry/mm/metadata.go @@ -167,3 +167,17 @@ func (mm *MemoryManager) SetExecutable(file fsbridge.File) { orig.DecRef() } } + +// VDSOSigReturn returns the address of vdso_sigreturn. +func (mm *MemoryManager) VDSOSigReturn() uint64 { + mm.metadataMu.Lock() + defer mm.metadataMu.Unlock() + return mm.vdsoSigReturnAddr +} + +// SetVDSOSigReturn sets the address of vdso_sigreturn. +func (mm *MemoryManager) SetVDSOSigReturn(addr uint64) { + mm.metadataMu.Lock() + defer mm.metadataMu.Unlock() + mm.vdsoSigReturnAddr = addr +} diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index c2195ae11..34d3bde7a 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -231,6 +231,9 @@ type MemoryManager struct { // before trying to activate the address space. When set to true, delays in // activation are not reported as stuck tasks by the watchdog. sleepForActivation bool + + // vdsoSigReturnAddr is the address of 'vdso_sigreturn'. 
+ vdsoSigReturnAddr uint64 } // vma represents a virtual memory area. diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 79045651e..716198712 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -18,6 +18,8 @@ package kvm import ( "syscall" + + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) type kvmOneReg struct { @@ -46,6 +48,6 @@ type userRegs struct { func updateGlobalOnce(fd int) error { physicalInit() err := updateSystemValues(int(fd)) - updateVectorTable() + ring0.Init() return err } diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index b531f2f85..3b35858ae 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -48,69 +48,6 @@ func (m *machine) initArchState() error { return nil } -func getPageWithReflect(p uintptr) []byte { - return (*(*[0xFFFFFF]byte)(unsafe.Pointer(p & ^uintptr(syscall.Getpagesize()-1))))[:syscall.Getpagesize()] -} - -// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. -// -// According to the design documentation of Arm64, -// the start address of exception vector table should be 11-bits aligned. -// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S -// But, we can't align a function's start address to a specific address by using golang. -// We have raised this question in golang community: -// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I -// This function will be removed when golang supports this feature. -// -// There are 2 jobs were implemented in this function: -// 1, move the start address of exception vector table into the specific address. -// 2, modify the offset of each instruction. -func updateVectorTable() { - fromLocation := reflect.ValueOf(ring0.Vectors).Pointer() - offset := fromLocation & (1<<11 - 1) - if offset != 0 { - offset = 1<<11 - offset - } - - toLocation := fromLocation + offset - page := getPageWithReflect(toLocation) - if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil { - panic(err) - } - - page = getPageWithReflect(toLocation + 4096) - if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil { - panic(err) - } - - // Move exception-vector-table into the specific address. - var entry *uint32 - var entryFrom *uint32 - for i := 1; i <= 0x800; i++ { - entry = (*uint32)(unsafe.Pointer(toLocation + 0x800 - uintptr(i))) - entryFrom = (*uint32)(unsafe.Pointer(fromLocation + 0x800 - uintptr(i))) - *entry = *entryFrom - } - - // The offset from the address of each unconditionally branch is changed. - // We should modify the offset of each instruction. - nums := []uint32{0x0, 0x80, 0x100, 0x180, 0x200, 0x280, 0x300, 0x380, 0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780} - for _, num := range nums { - entry = (*uint32)(unsafe.Pointer(toLocation + uintptr(num))) - *entry = *entry - (uint32)(offset/4) - } - - page = getPageWithReflect(toLocation) - if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil { - panic(err) - } - - page = getPageWithReflect(toLocation + 4096) - if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil { - panic(err) - } -} - // initArchState initializes architecture-specific state. 
func (c *vCPU) initArchState() error { var ( diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 934b6fbcd..b69520030 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -72,11 +72,13 @@ go_library( "lib_amd64.s", "lib_arm64.go", "lib_arm64.s", + "lib_arm64_unsafe.go", "ring0.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/cpuid", + "//pkg/safecopy", "//pkg/sentry/platform/ring0/pagetables", "//pkg/usermem", ], diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index af075aae4..242b9305c 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -37,3 +37,10 @@ func SaveVRegs(*byte) // LoadVRegs loads V0-V31 registers. func LoadVRegs(*byte) + +// Init sets function pointers based on architectural features. +// +// This must be called prior to using ring0. +func Init() { + rewriteVectors() +} diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go new file mode 100644 index 000000000..c05166fea --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go @@ -0,0 +1,108 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +import ( + "reflect" + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/usermem" +) + +const ( + nopInstruction = 0xd503201f + instSize = unsafe.Sizeof(uint32(0)) + vectorsRawLen = 0x800 +) + +func unsafeSlice(addr uintptr, length int) (slice []uint32) { + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + hdr.Data = addr + hdr.Len = length / int(instSize) + hdr.Cap = length / int(instSize) + return slice +} + +// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. +// +// According to the design documentation of Arm64, +// the start address of exception vector table should be 11-bits aligned. +// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S +// But, we can't align a function's start address to a specific address by using golang. +// We have raised this question in golang community: +// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I +// This function will be removed when golang supports this feature. +// +// There are 2 jobs were implemented in this function: +// 1, move the start address of exception vector table into the specific address. +// 2, modify the offset of each instruction. +func rewriteVectors() { + vectorsBegin := reflect.ValueOf(Vectors).Pointer() + + // The exception-vector-table is required to be 11-bits aligned. + // And the size is 0x800. + // Please see the documentation as reference: + // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table + // + // But, golang does not allow to set a function's address to a specific value. 
+ // So, for gvisor, I defined the size of exception-vector-table as 4K, + // filled the 2nd 2K part with NOP-s. + // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. + // + // So, the prerequisite for this function to work correctly is: + // vectorsSafeLen >= 0x1000 + // vectorsRawLen = 0x800 + vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin) + if vectorsSafeLen < 2*vectorsRawLen { + panic("Can't update vectors") + } + + vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32 + vectorsRawLen32 := vectorsRawLen / int(instSize) + + offset := vectorsBegin & (1<<11 - 1) + if offset != 0 { + offset = 1<<11 - offset + } + + pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1) + + _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)) + if errno != 0 { + panic(errno.Error()) + } + + offset = offset / instSize // By index, not bytes. + // Move exception-vector-table into the specific address, should uses memmove here. + for i := 1; i <= vectorsRawLen32; i++ { + vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i] + } + + // Adjust branch since instruction was moved forward. + for i := 0; i < vectorsRawLen32; i++ { + if vectorsSafeTable[int(offset)+i] != nopInstruction { + vectorsSafeTable[int(offset)+i] -= uint32(offset) + } + } + + _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC)) + if errno != 0 { + panic(errno.Error()) + } +} diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 5d0085462..7ac38764d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -300,7 +300,7 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { return nil, syserr.TranslateNetstackError(err) } } @@ -535,7 +535,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } if resCh != nil { - t := ctx.(*kernel.Task) + t := kernel.TaskFromContext(ctx) if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err).ToError() } @@ -608,7 +608,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } if resCh != nil { - t := ctx.(*kernel.Task) + t := kernel.TaskFromContext(ctx) if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err).ToError() } @@ -965,6 +965,13 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return nil, syserr.ErrProtocolNotAvailable } +func boolToInt32(v bool) int32 { + if v { + return 1 + } + return 0 +} + // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. 
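The getsockopt conversions below all follow one template: query the endpoint with GetSockOptBool and widen the result with boolToInt32, instead of round-tripping through a named option struct. TCP_NODELAY is the one inverted case, since tcpip.DelayOption means "delay enabled" while TCP_NODELAY means the opposite. A sketch of the shared shape (using SockOptBool as the option type is an assumption about tcpip's API at this revision):

func getBoolSockOptSketch(ep commonEndpoint, opt tcpip.SockOptBool, invert bool) (int32, *syserr.Error) {
	v, err := ep.GetSockOptBool(opt)
	if err != nil {
		return 0, syserr.TranslateNetstackError(err)
	}
	if invert {
		v = !v // e.g. TCP_NODELAY reports the negation of DelayOption.
	}
	return boolToInt32(v), nil
}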
@@ -998,12 +1005,11 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.PasscredOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.PasscredOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1042,24 +1048,22 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.ReuseAddressOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReusePortOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReusePortOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1089,24 +1093,22 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.BroadcastOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.BroadcastOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.KeepaliveEnabledOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { @@ -1156,47 +1158,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptInt(tcpip.DelayOption) + v, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v == 0 { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(!v), nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.CorkOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.CorkOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.QuickAckOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.QuickAckOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MaxSegOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MaxSegOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1328,11 +1324,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o 
int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1342,8 +1334,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if outLen == 0 { return make([]byte, 0), nil } - var v tcpip.IPv6TrafficClassOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1365,12 +1357,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIPv6(t, name) @@ -1386,8 +1373,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.TTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.TTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1403,8 +1390,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastTTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1429,23 +1416,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastLoopOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(v), nil case linux.IP_TOS: // Length handling for parity with Linux. 
if outLen == 0 { return []byte(nil), nil } - var v tcpip.IPv4TOSOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { @@ -1462,11 +1445,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1477,11 +1456,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIP(t, name) @@ -1592,7 +1567,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1600,7 +1575,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1628,7 +1603,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1636,7 +1611,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1644,7 +1619,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1716,11 +1691,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o int - if v == 0 { - o = 1 - } - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1728,7 +1699,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -1736,7 +1707,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name 
int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -1744,7 +1715,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v))) case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { @@ -1855,7 +1826,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) if v == -1 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v))) case linux.IPV6_RECVTCLASS: v, err := parseIntOrChar(optVal) @@ -1940,7 +1911,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if v < 0 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v))) case linux.IP_ADD_MEMBERSHIP: req, err := copyInMulticastRequest(optVal, false /* allowAddr */) @@ -1987,9 +1958,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.MulticastLoopOption(v != 0), - )) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2008,7 +1977,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } else if v < 1 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v))) case linux.IP_TOS: if len(optVal) == 0 { @@ -2018,7 +1987,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v))) case linux.IP_RECVTOS: v, err := parseIntOrChar(optVal) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 74bcd6300..c708b6030 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/ilist", + "//pkg/log", "//pkg/refs", "//pkg/sync", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2ef654235..2f1b127df 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" @@ -838,24 +839,43 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. Currently not supported. 
func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.PasscredOption: - e.setPasscred(v != 0) - return nil - } return nil } func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: + case tcpip.PasscredOption: + e.setPasscred(v) + case tcpip.ReuseAddressOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + } return nil } func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + } return nil } func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.PasscredOption: + return e.Passcred(), nil + + default: + log.Warningf("Unsupported socket option: %d", opt) + return false, tcpip.ErrUnknownProtocolOption + } } func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { @@ -914,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return int(v), nil default: + log.Warningf("Unsupported socket option: %d", opt) return -1, tcpip.ErrUnknownProtocolOption } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) - } - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: + log.Warningf("Unsupported socket option: %T", opt) return tcpip.ErrUnknownProtocolOption } } diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 77655558e..68ca537c8 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -719,7 +719,7 @@ func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.Syscal // SyscallExit implements kernel.Stracer.SyscallExit. It logs the syscall // exit trace. func (s SyscallMap) SyscallExit(context interface{}, t *kernel.Task, sysno, rval uintptr, err error) { - errno := t.ExtractErrno(err, int(sysno)) + errno := kernel.ExtractErrno(err, int(sysno)) c := context.(*syscallContext) elapsed := time.Since(c.start) @@ -778,9 +778,6 @@ func (s SyscallMap) Name(sysno uintptr) string { // // N.B. This is not in an init function because we can't be sure all syscall // tables are registered with the kernel when init runs. -// -// TODO(gvisor.dev/issue/155): remove kernel package dependencies from this -// package and have the kernel package self-initialize all syscall tables. func Initialize() { for _, table := range kernel.SyscallTables() { // Is this known? diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index b401978db..d781d6a04 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -114,14 +114,28 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { id := args[0].Uint64() - // Destroy the given context. 
- if !t.MemoryManager().DestroyAIOContext(t, id) {
+ ctx := t.MemoryManager().DestroyAIOContext(t, id)
+ if ctx == nil {
// Does not exist.
return 0, nil, syserror.EINVAL
}
- // FIXME(fvoznika): Linux blocks until all AIO to the destroyed context is
- // done.
- return 0, nil, nil
+
+ // Drain completed requests and wait for pending requests until there are no
+ // more.
+ for {
+ ctx.Drain()
+
+ ch := ctx.WaitChannel()
+ if ch == nil {
+ // No more requests, we're done.
+ return 0, nil, nil
+ }
+ // The task cannot be interrupted during the wait. Equivalent to
+ // TASK_UNINTERRUPTIBLE in Linux.
+ t.UninterruptibleSleepStart(true /* deactivate */)
+ <-ch
+ t.UninterruptibleSleepFinish(true /* activate */)
+ }
}

// IoGetevents implements linux syscall io_getevents(2).
@@ -200,13 +214,13 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadline ktime.Time) (interface{}, error) {
for {
if v, ok := ctx.PopRequest(); ok {
- // Request was readly available. Just return it.
+ // Request was readily available. Just return it.
return v, nil
}

// Need to wait for request completion.
- done, active := ctx.WaitChannel()
- if !active {
+ done := ctx.WaitChannel()
+ if done == nil {
// Context has been destroyed.
return nil, syserror.EINVAL
}
@@ -248,6 +262,10 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
}

func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) {
+ if ctx.Dead() {
+ ctx.CancelPendingRequest()
+ return
+ }
ev := &ioEvent{
Data: cb.Data,
Obj: uint64(cbAddr),
@@ -272,7 +290,7 @@ func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioC
// Update the result.
if err != nil {
err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
- ev.Result = -int64(t.ExtractErrno(err, 0))
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
}

file.DecRef()
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 798344042..43c510930 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)

+// LINT.IfChange
+
// pipe2 implements the actual system call with flags.
func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
@@ -45,10 +47,12 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
}

if _, err := t.CopyOut(addr, fds); err != nil {
- // The files are not closed in this case, the exact semantics
- // of this error case are not well defined, but they could have
- // already been observed by user space.
- return 0, syserror.EFAULT + for _, fd := range fds { + if file, _ := t.FDTable().Remove(fd); file != nil { + file.DecRef() + } + } + return 0, err } return 0, nil } @@ -69,3 +73,5 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := pipe2(t, addr, flags) return n, nil, err } + +// LINT.ThenChange(vfs2/pipe.go) diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 9c6728530..f92bf8096 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -161,8 +161,8 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 { return 0, nil, syserror.EINVAL } - // no_new_privs is assumed to always be set. See - // kernel.Task.updateCredsForExec. + // PR_SET_NO_NEW_PRIVS is assumed to always be set. + // See kernel.Task.updateCredsForExecLocked. return 0, nil, nil case linux.PR_GET_NO_NEW_PRIVS: diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 78a2cb750..071b4bacc 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -96,8 +96,8 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EINVAL } - // Check that the offset is legitimate. - if offset < 0 { + // Check that the offset is legitimate and does not overflow. + if offset < 0 || offset+int64(size) < 0 { return 0, nil, syserror.EINVAL } @@ -120,8 +120,8 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } defer file.DecRef() - // Check that the offset is legitimate. - if offset < 0 { + // Check that the offset is legitimate and does not overflow. + if offset < 0 || offset+int64(size) < 0 { return 0, nil, syserror.EINVAL } diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index e08c333d6..d5d5b6959 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -197,7 +197,7 @@ func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // saved set user IDs of the target process must match the real user ID of // the caller and the real, effective, and saved set group IDs of the // target process must match the real group ID of the caller." - if !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) { + if ot != t && !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) { cred, tcred := t.Credentials(), ot.Credentials() if cred.RealKUID != tcred.RealKUID || cred.RealKUID != tcred.EffectiveKUID || diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 2919228d0..0760af77b 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -31,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + // minListenBacklog is the minimum reasonable backlog for listening sockets. const minListenBacklog = 8 @@ -244,7 +246,11 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Copy the file descriptors out. if _, err := t.CopyOut(socks, fds); err != nil { - // Note that we don't close files here; see pipe(2) also. 
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef()
+ }
+ }
return 0, nil, err
}

@@ -1128,3 +1134,5 @@ func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
return n, nil, err
}
+
+// LINT.ThenChange(./vfs2/socket.go)
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index fd642834b..df0d0f461 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -25,10 +25,14 @@ import (

// doSplice implements a blocking splice operation.
func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
- if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 {
+ if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
return 0, syserror.EINVAL
}

+ if opts.Length > int64(kernel.MAX_RW_COUNT) {
+ opts.Length = int64(kernel.MAX_RW_COUNT)
+ }
+
var (
total int64
n int64
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 506ee54ce..6ec0de96e 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -87,8 +87,8 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
defer file.DecRef()

- // Check that the offset is legitimate.
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 { return 0, nil, syserror.EINVAL } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 0004e60d9..6ff2d84d2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -18,9 +18,11 @@ go_library( "linux64_override_arm64.go", "mmap.go", "path.go", + "pipe.go", "poll.go", "read_write.go", "setstat.go", + "socket.go", "stat.go", "stat_amd64.go", "stat_arm64.go", @@ -32,21 +34,28 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/binary", "//pkg/bits", "//pkg/fspath", "//pkg/gohacks", "//pkg/sentry/arch", "//pkg/sentry/fsbridge", + "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/loader", "//pkg/sentry/memmap", + "//pkg/sentry/socket", + "//pkg/sentry/socket/control", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/syscalls", "//pkg/sentry/syscalls/linux", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/syserr", "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 3afcea665..8181d80f4 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/syserror" ) @@ -140,6 +141,22 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(file.StatusFlags()), nil, nil case linux.F_SETFL: return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint()) + case linux.F_SETPIPE_SZ: + pipefile, ok := file.Impl().(*pipe.VFSPipeFD) + if !ok { + return 0, nil, syserror.EBADF + } + n, err := pipefile.SetPipeSize(int64(args[2].Int())) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil + case linux.F_GETPIPE_SZ: + pipefile, ok := file.Impl().(*pipe.VFSPipeFD) + if !ok { + return 0, nil, syserror.EBADF + } + return uintptr(pipefile.PipeSize()), nil, nil default: // TODO(gvisor.dev/issue/1623): Everything else is not yet supported. 
return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index a859095e2..46d3e189c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -172,7 +172,7 @@ func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mo defer tpop.Release() file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{ - Flags: flags, + Flags: flags | linux.O_LARGEFILE, Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()), }) if err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go index a61cc5059..62e98817d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/getdents.go +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -97,6 +97,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { // char d_name[]; /* Filename (null-terminated) */ // }; size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name) + size = (size + 7) &^ 7 // round up to multiple of 8 if size > cb.remaining { return syserror.EINVAL } @@ -106,7 +107,12 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) buf[18] = dirent.Type copy(buf[19:], dirent.Name) - buf[size-1] = 0 // NUL terminator + // Zero out all remaining bytes in buf, including the NUL terminator + // after dirent.Name. + bufTail := buf[19+len(dirent.Name):] + for i := range bufTail { + bufTail[i] = 0 + } } else { // struct linux_dirent { // unsigned long d_ino; /* Inode number */ @@ -125,6 +131,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width())) } size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name) + size = (size + 7) &^ 7 // round up to multiple of sizeof(long) if size > cb.remaining { return syserror.EINVAL } @@ -133,9 +140,14 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) copy(buf[18:], dirent.Name) - buf[size-3] = 0 // NUL terminator - buf[size-2] = 0 // zero padding byte - buf[size-1] = dirent.Type + // Zero out all remaining bytes in buf, including the NUL terminator + // after dirent.Name and the zero padding byte between the name and + // dirent type. 
+ bufTail := buf[18+len(dirent.Name):]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
+ bufTail[2] = dirent.Type
}
n, err := cb.t.CopyOutBytes(cb.addr, buf)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index 63febc2f7..21eb98444 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -39,26 +39,27 @@ func Override(table map[uintptr]kernel.Syscall) {
table[19] = syscalls.Supported("readv", Readv)
table[20] = syscalls.Supported("writev", Writev)
table[21] = syscalls.Supported("access", Access)
- delete(table, 22) // pipe
+ table[22] = syscalls.Supported("pipe", Pipe)
table[23] = syscalls.Supported("select", Select)
table[32] = syscalls.Supported("dup", Dup)
table[33] = syscalls.Supported("dup2", Dup2)
delete(table, 40) // sendfile
- delete(table, 41) // socket
- delete(table, 42) // connect
- delete(table, 43) // accept
- delete(table, 44) // sendto
- delete(table, 45) // recvfrom
- delete(table, 46) // sendmsg
- delete(table, 47) // recvmsg
- delete(table, 48) // shutdown
- delete(table, 49) // bind
- delete(table, 50) // listen
- delete(table, 51) // getsockname
- delete(table, 52) // getpeername
- delete(table, 53) // socketpair
- delete(table, 54) // setsockopt
- delete(table, 55) // getsockopt
+ // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2.
+ table[41] = syscalls.PartiallySupported("socket", Socket, "In process of porting socket syscalls to VFS2.", nil)
+ table[42] = syscalls.PartiallySupported("connect", Connect, "In process of porting socket syscalls to VFS2.", nil)
+ table[43] = syscalls.PartiallySupported("accept", Accept, "In process of porting socket syscalls to VFS2.", nil)
+ table[44] = syscalls.PartiallySupported("sendto", SendTo, "In process of porting socket syscalls to VFS2.", nil)
+ table[45] = syscalls.PartiallySupported("recvfrom", RecvFrom, "In process of porting socket syscalls to VFS2.", nil)
+ table[46] = syscalls.PartiallySupported("sendmsg", SendMsg, "In process of porting socket syscalls to VFS2.", nil)
+ table[47] = syscalls.PartiallySupported("recvmsg", RecvMsg, "In process of porting socket syscalls to VFS2.", nil)
+ table[48] = syscalls.PartiallySupported("shutdown", Shutdown, "In process of porting socket syscalls to VFS2.", nil)
+ table[49] = syscalls.PartiallySupported("bind", Bind, "In process of porting socket syscalls to VFS2.", nil)
+ table[50] = syscalls.PartiallySupported("listen", Listen, "In process of porting socket syscalls to VFS2.", nil)
+ table[51] = syscalls.PartiallySupported("getsockname", GetSockName, "In process of porting socket syscalls to VFS2.", nil)
+ table[52] = syscalls.PartiallySupported("getpeername", GetPeerName, "In process of porting socket syscalls to VFS2.", nil)
+ table[53] = syscalls.PartiallySupported("socketpair", SocketPair, "In process of porting socket syscalls to VFS2.", nil)
+ table[54] = syscalls.PartiallySupported("setsockopt", SetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
+ table[55] = syscalls.PartiallySupported("getsockopt", GetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
table[59] = syscalls.Supported("execve", Execve)
table[72] = syscalls.Supported("fcntl", Fcntl)
delete(table, 73) // flock
@@ -144,18 +145,21 @@ func Override(table map[uintptr]kernel.Syscall) {
delete(table, 285) // fallocate
table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime)
table[287] =
syscalls.Supported("timerfd_gettime", TimerfdGettime) - delete(table, 288) // accept4 + // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. + table[288] = syscalls.PartiallySupported("accept4", Accept4, "In process of porting socket syscalls to VFS2.", nil) delete(table, 289) // signalfd4 delete(table, 290) // eventfd2 table[291] = syscalls.Supported("epoll_create1", EpollCreate1) table[292] = syscalls.Supported("dup3", Dup3) - delete(table, 293) // pipe2 + table[293] = syscalls.Supported("pipe2", Pipe2) delete(table, 294) // inotify_init1 table[295] = syscalls.Supported("preadv", Preadv) table[296] = syscalls.Supported("pwritev", Pwritev) - delete(table, 299) // recvmmsg + // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. + table[299] = syscalls.PartiallySupported("recvmmsg", RecvMMsg, "In process of porting socket syscalls to VFS2.", nil) table[306] = syscalls.Supported("syncfs", Syncfs) - delete(table, 307) // sendmmsg + // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. + table[307] = syscalls.PartiallySupported("sendmmsg", SendMMsg, "In process of porting socket syscalls to VFS2.", nil) table[316] = syscalls.Supported("renameat2", Renameat2) delete(table, 319) // memfd_create table[322] = syscalls.Supported("execveat", Execveat) diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go new file mode 100644 index 000000000..4a01e4209 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Pipe implements Linux syscall pipe(2). +func Pipe(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + return 0, nil, pipe2(t, addr, 0) +} + +// Pipe2 implements Linux syscall pipe2(2). 
+func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + flags := args[1].Int() + return 0, nil, pipe2(t, addr, flags) +} + +func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { + if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { + return syserror.EINVAL + } + r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) + defer r.DecRef() + defer w.DecRef() + + fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{ + CloseOnExec: flags&linux.O_CLOEXEC != 0, + }) + if err != nil { + return err + } + if _, err := t.CopyOut(addr, fds); err != nil { + for _, fd := range fds { + if _, file := t.FDTable().Remove(fd); file != nil { + file.DecRef() + } + } + return err + } + return nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index 35f6308d6..6c6998f45 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -103,7 +103,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt // Issue the request and break out if it completes with anything other than // "would block". - n, err := file.Read(t, dst, opts) + n, err = file.Read(t, dst, opts) total += n if err != syserror.ErrWouldBlock { break @@ -130,8 +130,8 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } defer file.DecRef() - // Check that the offset is legitimate. - if offset < 0 { + // Check that the offset is legitimate and does not overflow. + if offset < 0 || offset+int64(size) < 0 { return 0, nil, syserror.EINVAL } @@ -248,7 +248,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of // Issue the request and break out if it completes with anything other than // "would block". - n, err := file.PRead(t, dst, offset+total, opts) + n, err = file.PRead(t, dst, offset+total, opts) total += n if err != syserror.ErrWouldBlock { break @@ -335,7 +335,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op // Issue the request and break out if it completes with anything other than // "would block". - n, err := file.Write(t, src, opts) + n, err = file.Write(t, src, opts) total += n if err != syserror.ErrWouldBlock { break @@ -362,8 +362,8 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } defer file.DecRef() - // Check that the offset is legitimate. - if offset < 0 { + // Check that the offset is legitimate and does not overflow. + if offset < 0 || offset+int64(size) < 0 { return 0, nil, syserror.EINVAL } @@ -480,7 +480,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o // Issue the request and break out if it completes with anything other than // "would block". - n, err := file.PWrite(t, src, offset+total, opts) + n, err = file.PWrite(t, src, offset+total, opts) total += n if err != syserror.ErrWouldBlock { break diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go new file mode 100644 index 000000000..b1ede32f0 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -0,0 +1,1139 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// minListenBacklog is the minimum reasonable backlog for listening sockets.
+const minListenBacklog = 8
+
+// maxListenBacklog is the maximum allowed backlog for listening sockets.
+const maxListenBacklog = 1024
+
+// maxAddrLen is the maximum socket address length we're willing to accept.
+const maxAddrLen = 200
+
+// maxOptLen is the maximum sockopt parameter length we're willing to accept.
+const maxOptLen = 1024 * 8
+
+// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
+// willing to accept. Note that this limit is smaller than Linux, which allows
+// buffers up to INT_MAX.
+const maxControlLen = 10 * 1024 * 1024
+
+// nameLenOffset is the offset from the start of the MessageHeader64 struct to
+// the NameLen field.
+const nameLenOffset = 8
+
+// controlLenOffset is the offset from the start of the MessageHeader64 struct
+// to the ControlLen field.
+const controlLenOffset = 40
+
+// flagsOffset is the offset from the start of the MessageHeader64 struct
+// to the Flags field.
+const flagsOffset = 48
+
+const sizeOfInt32 = 4
+
+// messageHeader64Len is the length of a MessageHeader64 struct.
+var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+
+// multipleMessageHeader64Len is the length of a multipleMessageHeader64 struct.
+var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+
+// baseRecvFlags are the flags that are accepted across recvmsg(2),
+// recvmmsg(2), and recvfrom(2).
+const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
+
+// MessageHeader64 is the 64-bit representation of the msghdr struct used in
+// the recvmsg and sendmsg syscalls.
+type MessageHeader64 struct {
+ // Name is the optional pointer to a network address buffer.
+ Name uint64
+
+ // NameLen is the length of the buffer pointed to by Name.
+ NameLen uint32
+ _ uint32
+
+ // Iov is a pointer to an array of io vectors that describe the memory
+ // locations involved in the io operation.
+ Iov uint64
+
+ // IovLen is the length of the array pointed to by Iov.
+ IovLen uint64
+
+ // Control is the optional pointer to ancillary control data.
+ Control uint64
+
+ // ControlLen is the length of the data pointed to by Control.
+ ControlLen uint64
+
+ // Flags on the sent/received message.
+ Flags int32
+ _ int32
+}
+
+// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
+// the recvmmsg and sendmmsg syscalls.
+type multipleMessageHeader64 struct {
+ msgHdr MessageHeader64
+ msgLen uint32
+ _ int32
+}
+
+// CopyInMessageHeader64 copies a message header from user to kernel memory.
+func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
+ b := t.CopyScratchBuffer(52)
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return err
+ }
+
+ msg.Name = usermem.ByteOrder.Uint64(b[0:])
+ msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
+ msg.Iov = usermem.ByteOrder.Uint64(b[16:])
+ msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
+ msg.Control = usermem.ByteOrder.Uint64(b[32:])
+ msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
+ msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
+
+ return nil
+}
+
+// CaptureAddress allocates memory for and copies a socket address structure
+// from the untrusted address space range.
+func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+ if addrlen > maxAddrLen {
+ return nil, syserror.EINVAL
+ }
+
+ addrBuf := make([]byte, addrlen)
+ if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
+ return nil, err
+ }
+
+ return addrBuf, nil
+}
+
+// writeAddress writes a sockaddr structure and its length to an output buffer
+// in the untrusted address space range. If the address is bigger than the
+// buffer, it is truncated.
+func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+ // Get the buffer length.
+ var bufLen uint32
+ if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ return err
+ }
+
+ if int32(bufLen) < 0 {
+ return syserror.EINVAL
+ }
+
+ // Write the length unconditionally.
+ if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ return err
+ }
+
+ if addr == nil {
+ return nil
+ }
+
+ if bufLen > addrLen {
+ bufLen = addrLen
+ }
+
+ // Copy as much of the address as will fit in the buffer.
+ encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ if bufLen > uint32(len(encodedAddr)) {
+ bufLen = uint32(len(encodedAddr))
+ }
+ _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
+ return err
+}
+
+// Socket implements the linux syscall socket(2).
+func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the new socket.
+ s, e := socket.NewVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ defer s.DecRef()
+
+ if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, err
+ }
+
+ fd, err := t.NewFDFromVFS2(0, s, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// SocketPair implements the linux syscall socketpair(2).
+func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+ addr := args[3].Pointer()
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the socket pair.
+ s1, s2, e := socket.PairVFS2(t, domain, linux.SockType(stype&0xf), protocol) + if e != nil { + return 0, nil, e.ToError() + } + // Adding to the FD table will cause an extra reference to be acquired. + defer s1.DecRef() + defer s2.DecRef() + + nonblocking := uint32(stype & linux.SOCK_NONBLOCK) + if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil { + return 0, nil, err + } + if err := s2.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil { + return 0, nil, err + } + + // Create the FDs for the sockets. + flags := kernel.FDFlags{ + CloseOnExec: stype&linux.SOCK_CLOEXEC != 0, + } + fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{s1, s2}, flags) + if err != nil { + return 0, nil, err + } + + if _, err := t.CopyOut(addr, fds); err != nil { + for _, fd := range fds { + if _, file := t.FDTable().Remove(fd); file != nil { + file.DecRef() + } + } + return 0, nil, err + } + + return 0, nil, nil +} + +// Connect implements the linux syscall connect(2). +func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Uint() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Capture address and call syscall implementation. + a, err := CaptureAddress(t, addr, addrlen) + if err != nil { + return 0, nil, err + } + + blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 + return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS) +} + +// accept is the implementation of the accept syscall. It is called by accept +// and accept4 syscall handlers. +func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) { + // Check that no unsupported flags are passed in. + if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { + return 0, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, syserror.ENOTSOCK + } + + // Call the syscall implementation for this socket, then copy the + // output address if one is specified. + blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 + + peerRequested := addrLen != 0 + nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) + if e != nil { + return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + } + if peerRequested { + // NOTE(magi): Linux does not give you an error if it can't + // write the data back out so neither do we. + if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL { + return 0, err + } + } + return uintptr(nfd), nil +} + +// Accept4 implements the linux syscall accept4(2). +func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + flags := int(args[3].Int()) + + n, err := accept(t, fd, addr, addrlen, flags) + return n, nil, err +} + +// Accept implements the linux syscall accept(2). 
+func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + + n, err := accept(t, fd, addr, addrlen, 0) + return n, nil, err +} + +// Bind implements the linux syscall bind(2). +func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Uint() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Capture address and call syscall implementation. + a, err := CaptureAddress(t, addr, addrlen) + if err != nil { + return 0, nil, err + } + + return 0, nil, s.Bind(t, a).ToError() +} + +// Listen implements the linux syscall listen(2). +func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + backlog := args[1].Int() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Per Linux, the backlog is silently capped to reasonable values. + if backlog <= 0 { + backlog = minListenBacklog + } + if backlog > maxListenBacklog { + backlog = maxListenBacklog + } + + return 0, nil, s.Listen(t, int(backlog)).ToError() +} + +// Shutdown implements the linux syscall shutdown(2). +func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + how := args[1].Int() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Validate how, then call syscall implementation. + switch how { + case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR: + default: + return 0, nil, syserror.EINVAL + } + + return 0, nil, s.Shutdown(t, int(how)).ToError() +} + +// GetSockOpt implements the linux syscall getsockopt(2). +func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + level := args[1].Int() + name := args[2].Int() + optValAddr := args[3].Pointer() + optLenAddr := args[4].Pointer() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Read the length. Reject negative values. + optLen := int32(0) + if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + return 0, nil, err + } + if optLen < 0 { + return 0, nil, syserror.EINVAL + } + + // Call syscall implementation then copy both value and value len out. 
+ v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen)) + if e != nil { + return 0, nil, e.ToError() + } + + vLen := int32(binary.Size(v)) + if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + return 0, nil, err + } + + if v != nil { + if _, err := t.CopyOut(optValAddr, v); err != nil { + return 0, nil, err + } + } + + return 0, nil, nil +} + +// getSockOpt tries to handle common socket options, or dispatches to a specific +// socket implementation. +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { + if level == linux.SOL_SOCKET { + switch name { + case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: + if len < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + } + + switch name { + case linux.SO_TYPE: + _, skType, _ := s.Type() + return int32(skType), nil + case linux.SO_DOMAIN: + family, _, _ := s.Type() + return int32(family), nil + case linux.SO_PROTOCOL: + _, _, protocol := s.Type() + return int32(protocol), nil + } + } + + return s.GetSockOpt(t, level, name, optValAddr, len) +} + +// SetSockOpt implements the linux syscall setsockopt(2). +// +// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket. +func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + level := args[1].Int() + name := args[2].Int() + optValAddr := args[3].Pointer() + optLen := args[4].Int() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + if optLen < 0 { + return 0, nil, syserror.EINVAL + } + if optLen > maxOptLen { + return 0, nil, syserror.EINVAL + } + buf := t.CopyScratchBuffer(int(optLen)) + if _, err := t.CopyIn(optValAddr, &buf); err != nil { + return 0, nil, err + } + + // Call syscall implementation. + if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil { + return 0, nil, err.ToError() + } + + return 0, nil, nil +} + +// GetSockName implements the linux syscall getsockname(2). +func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Get the socket name and copy it to the caller. + v, vl, err := s.GetSockName(t) + if err != nil { + return 0, nil, err.ToError() + } + + return 0, nil, writeAddress(t, v, vl, addr, addrlen) +} + +// GetPeerName implements the linux syscall getpeername(2). +func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Get the socket peer name and copy it to the caller. 
+ v, vl, err := s.GetPeerName(t) + if err != nil { + return 0, nil, err.ToError() + } + + return 0, nil, writeAddress(t, v, vl, addr, addrlen) +} + +// RecvMsg implements the linux syscall recvmsg(2). +func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + flags := args[2].Int() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Reject flags that we don't handle yet. + if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { + return 0, nil, syserror.EINVAL + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.RecvTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline) + return n, nil, err +} + +// RecvMMsg implements the linux syscall recvmmsg(2). +func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + vlen := args[2].Uint() + flags := args[3].Int() + toPtr := args[4].Pointer() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Reject flags that we don't handle yet. + if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + var haveDeadline bool + var deadline ktime.Time + if toPtr != 0 { + var ts linux.Timespec + if _, err := ts.CopyIn(t, toPtr); err != nil { + return 0, nil, err + } + if !ts.Valid() { + return 0, nil, syserror.EINVAL + } + deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration()) + haveDeadline = true + } + + if !haveDeadline { + if dl := s.RecvTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + } + + var count uint32 + var err error + for i := uint64(0); i < uint64(vlen); i++ { + mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + var n uintptr + if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil { + break + } + + // Copy the received length to the caller. 
+ lp, ok := mp.AddLength(messageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + if _, err = t.CopyOut(lp, uint32(n)); err != nil { + break + } + count++ + } + + if count == 0 { + return 0, nil, err + } + return uintptr(count), nil, nil +} + +func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { + // Capture the message header and io vectors. + var msg MessageHeader64 + if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + return 0, err + } + + if msg.IovLen > linux.UIO_MAXIOV { + return 0, syserror.EMSGSIZE + } + dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + // FIXME(b/63594852): Pretend we have an empty error queue. + if flags&linux.MSG_ERRQUEUE != 0 { + return 0, syserror.EAGAIN + } + + // Fast path when no control message nor name buffers are provided. + if msg.ControlLen == 0 && msg.NameLen == 0 { + n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) + if err != nil { + return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) + } + if !cms.Unix.Empty() { + mflags |= linux.MSG_CTRUNC + cms.Release() + } + + if int(msg.Flags) != mflags { + // Copy out the flags to the caller. + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + return 0, err + } + } + + return uintptr(n), nil + } + + if msg.ControlLen > maxControlLen { + return 0, syserror.ENOBUFS + } + n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) + if e != nil { + return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + } + defer cms.Release() + + controlData := make([]byte, 0, msg.ControlLen) + controlData = control.PackControlMessages(t, cms, controlData) + + if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { + creds, _ := cms.Unix.Credentials.(control.SCMCredentials) + controlData, mflags = control.PackCredentials(t, creds, controlData, mflags) + } + + if cms.Unix.Rights != nil { + controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags) + } + + // Copy the address to the caller. + if msg.NameLen != 0 { + if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil { + return 0, err + } + } + + // Copy the control data to the caller. + if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + return 0, err + } + if len(controlData) > 0 { + if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + return 0, err + } + } + + // Copy out the flags to the caller. + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + return 0, err + } + + return uintptr(n), nil +} + +// recvFrom is the implementation of the recvfrom syscall. It is called by +// recvfrom and recv syscall handlers. +func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) { + if int(bufLen) < 0 { + return 0, syserror.EINVAL + } + + // Reject flags that we don't handle yet. + if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 { + return 0, syserror.EINVAL + } + + // Get socket from the file descriptor. 
+ file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, syserror.ENOTSOCK + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.RecvTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) + cm.Release() + if e != nil { + return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + } + + // Copy the address to the caller. + if nameLenPtr != 0 { + if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil { + return 0, err + } + } + + return uintptr(n), nil +} + +// RecvFrom implements the linux syscall recvfrom(2). +func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + bufPtr := args[1].Pointer() + bufLen := args[2].Uint64() + flags := args[3].Int() + namePtr := args[4].Pointer() + nameLenPtr := args[5].Pointer() + + n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr) + return n, nil, err +} + +// SendMsg implements the linux syscall sendmsg(2). +func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + flags := args[2].Int() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Reject flags that we don't handle yet. + if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { + return 0, nil, syserror.EINVAL + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + n, err := sendSingleMsg(t, s, file, msgPtr, flags) + return n, nil, err +} + +// SendMMsg implements the linux syscall sendmmsg(2). +func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + vlen := args[2].Uint() + flags := args[3].Int() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Reject flags that we don't handle yet. 
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { + return 0, nil, syserror.EINVAL + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + var count uint32 + var err error + for i := uint64(0); i < uint64(vlen); i++ { + mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + var n uintptr + if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil { + break + } + + // Copy the received length to the caller. + lp, ok := mp.AddLength(messageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + if _, err = t.CopyOut(lp, uint32(n)); err != nil { + break + } + count++ + } + + if count == 0 { + return 0, nil, err + } + return uintptr(count), nil, nil +} + +func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) { + // Capture the message header. + var msg MessageHeader64 + if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + return 0, err + } + + var controlData []byte + if msg.ControlLen > 0 { + // Put an upper bound to prevent large allocations. + if msg.ControlLen > maxControlLen { + return 0, syserror.ENOBUFS + } + controlData = make([]byte, msg.ControlLen) + if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + return 0, err + } + } + + // Read the destination address if one is specified. + var to []byte + if msg.NameLen != 0 { + var err error + to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen) + if err != nil { + return 0, err + } + } + + // Read data then call the sendmsg implementation. + if msg.IovLen > linux.UIO_MAXIOV { + return 0, syserror.EMSGSIZE + } + src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + controlMessages, err := control.Parse(t, s, controlData) + if err != nil { + return 0, err + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.SendTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + // Call the syscall implementation. + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) + err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) + if err != nil { + controlMessages.Release() + } + return uintptr(n), err +} + +// sendTo is the implementation of the sendto syscall. It is called by sendto +// and send syscall handlers. +func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) { + bl := int(bufLen) + if bl < 0 { + return 0, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, syserror.ENOTSOCK + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + // Read the destination address if one is specified. 
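
sendSingleMsg above feeds the caller-supplied control buffer through control.Parse, which is how ancillary data such as SCM_RIGHTS reaches the endpoint. The following is a userspace sketch of the corresponding sendmsg(2)/recvmsg(2) exchange; it is illustrative only, not part of this change, and assumes golang.org/x/sys/unix.

package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	// A connected Unix datagram pair; fds[0] sends, fds[1] receives.
	fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}

	// Pass our stderr across the socket as SCM_RIGHTS ancillary data; this is
	// the kind of control message the sentry parses on the send path.
	oob := unix.UnixRights(int(os.Stderr.Fd()))
	if err := unix.Sendmsg(fds[0], []byte("hello"), oob, nil, 0); err != nil {
		panic(err)
	}

	buf := make([]byte, 16)
	oobBuf := make([]byte, 128)
	n, oobn, _, _, err := unix.Recvmsg(fds[1], buf, oobBuf, 0)
	if err != nil {
		panic(err)
	}
	fmt.Printf("data=%q oob=%d bytes\n", buf[:n], oobn)
}
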
+ var to []byte + var err error + if namePtr != 0 { + to, err = CaptureAddress(t, namePtr, nameLen) + if err != nil { + return 0, err + } + } + + src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.SendTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + // Call the syscall implementation. + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) + return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file) +} + +// SendTo implements the linux syscall sendto(2). +func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + bufPtr := args[1].Pointer() + bufLen := args[2].Uint64() + flags := args[3].Int() + namePtr := args[4].Pointer() + nameLen := args[5].Uint() + + n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen) + return n, nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index 89e9ff4d7..af455d5c1 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -51,7 +51,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml } defer tpop.Release() - names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop) + names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) if err != nil { return 0, nil, err } @@ -74,7 +74,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } defer file.DecRef() - names, err := file.Listxattr(t) + names, err := file.Listxattr(t, uint64(size)) if err != nil { return 0, nil, err } @@ -116,7 +116,10 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli return 0, nil, err } - value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name) + value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{ + Name: name, + Size: uint64(size), + }) if err != nil { return 0, nil, err } @@ -145,7 +148,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - value, err := file.Getxattr(t, name) + value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)}) if err != nil { return 0, nil, err } @@ -230,7 +233,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{ + return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{ Name: name, Value: value, Flags: uint32(flags), diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index d1f6dfb45..a64d86122 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -245,7 +245,7 @@ func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath } // ListxattrAt implements FilesystemImpl.ListxattrAt. 
-func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) { +func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { if !rp.Done() { return nil, syserror.ENOTDIR } @@ -253,7 +253,7 @@ func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([ } // GetxattrAt implements FilesystemImpl.GetxattrAt. -func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) { +func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) { if !rp.Done() { return "", syserror.ENOTDIR } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 20c545fca..5976b5ccd 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -122,7 +122,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mn } fd.refs = 1 - fd.statusFlags = statusFlags | linux.O_LARGEFILE + fd.statusFlags = statusFlags fd.vd = VirtualDentry{ mount: mnt, dentry: d, @@ -401,11 +401,11 @@ type FileDescriptionImpl interface { Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) // Listxattr returns all extended attribute names for the file. - Listxattr(ctx context.Context) ([]string, error) + Listxattr(ctx context.Context, size uint64) ([]string, error) // Getxattr returns the value associated with the given extended attribute // for the file. - Getxattr(ctx context.Context, name string) (string, error) + Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) // Setxattr changes the value associated with the given extended attribute // for the file. @@ -605,18 +605,23 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. // Listxattr returns all extended attribute names for the file represented by // fd. -func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { +// +// If the size of the list (including a NUL terminating byte after every entry) +// would exceed size, ERANGE may be returned. Note that implementations +// are free to ignore size entirely and return without error). In all cases, +// if size is 0, the list should be returned without error, regardless of size. +func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp) + names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size) vfsObj.putResolvingPath(rp) return names, err } - names, err := fd.impl.Listxattr(ctx) + names, err := fd.impl.Listxattr(ctx, size) if err == syserror.ENOTSUP { // Linux doesn't actually return ENOTSUP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security @@ -629,34 +634,39 @@ func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { // Getxattr returns the value associated with the given extended attribute for // the file represented by fd. -func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) { +// +// If the size of the return value exceeds opts.Size, ERANGE may be returned +// (note that implementations are free to ignore opts.Size entirely and return +// without error). 
In all cases, if opts.Size is 0, the value should be +// returned without error, regardless of size. +func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, name) + val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(rp) return val, err } - return fd.impl.Getxattr(ctx, name) + return fd.impl.Getxattr(ctx, *opts) } // Setxattr changes the value associated with the given extended attribute for // the file represented by fd. -func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error { +func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, opts) + err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(rp) return err } - return fd.impl.Setxattr(ctx, opts) + return fd.impl.Setxattr(ctx, *opts) } // Removexattr removes the given extended attribute from the file represented diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index d45e602ce..f4c111926 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -130,14 +130,14 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg // Listxattr implements FileDescriptionImpl.Listxattr analogously to // inode_operations::listxattr == NULL in Linux. -func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) { +func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) { // This isn't exactly accurate; see FileDescription.Listxattr. return nil, syserror.ENOTSUP } // Getxattr implements FileDescriptionImpl.Getxattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) { +func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) { return "", syserror.ENOTSUP } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index cd34782ff..a537a29d1 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -442,7 +442,13 @@ type FilesystemImpl interface { // - If extended attributes are not supported by the filesystem, // ListxattrAt returns nil. (See FileDescription.Listxattr for an // explanation.) - ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) + // + // - If the size of the list (including a NUL terminating byte after every + // entry) would exceed size, ERANGE may be returned. Note that + // implementations are free to ignore size entirely and return without + // error). In all cases, if size is 0, the list should be returned without + // error, regardless of size. + ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) // GetxattrAt returns the value associated with the given extended // attribute for the file at rp. 
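
The size/opts.Size plumbing documented above serves the usual userspace probing pattern: call getxattr(2) with a zero-sized buffer to learn the required size, then call again with a buffer of that size, treating ERANGE as a race with a concurrent update. A sketch of that pattern follows; it is illustrative only, assumes golang.org/x/sys/unix, and needs a filesystem that supports user.* attributes.

package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	// Hypothetical scratch file on an xattr-capable filesystem (e.g. ext4).
	f, err := os.Create("/var/tmp/xattr-demo")
	if err != nil {
		panic(err)
	}
	f.Close()

	if err := unix.Setxattr(f.Name(), "user.comment", []byte("hello"), 0); err != nil {
		panic(err)
	}

	// First call: an empty destination (size 0) asks only for the value's
	// size and must succeed regardless of how large the value is.
	sz, err := unix.Getxattr(f.Name(), "user.comment", nil)
	if err != nil {
		panic(err)
	}

	// Second call: a buffer of that size retrieves the value; a smaller
	// buffer would get ERANGE, matching the contract documented above.
	buf := make([]byte, sz)
	n, err := unix.Getxattr(f.Name(), "user.comment", buf)
	if err != nil {
		panic(err)
	}
	fmt.Printf("user.comment = %q\n", buf[:n])
}
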
@@ -451,7 +457,15 @@ type FilesystemImpl interface { // // - If extended attributes are not supported by the filesystem, GetxattrAt // returns ENOTSUP. - GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) + // + // - If an extended attribute named opts.Name does not exist, ENODATA is + // returned. + // + // - If the size of the return value exceeds opts.Size, ERANGE may be + // returned (note that implementations are free to ignore opts.Size entirely + // and return without error). In all cases, if opts.Size is 0, the value + // should be returned without error, regardless of size. + GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) // SetxattrAt changes the value associated with the given extended // attribute for the file at rp. @@ -460,6 +474,10 @@ type FilesystemImpl interface { // // - If extended attributes are not supported by the filesystem, SetxattrAt // returns ENOTSUP. + // + // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists, + // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist, + // ENODATA is returned. SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error // RemovexattrAt removes the given extended attribute from the file at rp. @@ -468,6 +486,8 @@ type FilesystemImpl interface { // // - If extended attributes are not supported by the filesystem, // RemovexattrAt returns ENOTSUP. + // + // - If name does not exist, ENODATA is returned. RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error // BoundEndpointAt returns the Unix socket endpoint bound at the path rp. @@ -497,7 +517,7 @@ type FilesystemImpl interface { // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error - // TODO: inotify_add_watch() + // TODO(gvisor.dev/issue/1479): inotify_add_watch() } // PrependPathAtVFSRootError is returned by implementations of diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD new file mode 100644 index 000000000..d8c4d27b9 --- /dev/null +++ b/pkg/sentry/vfs/memxattr/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "memxattr", + srcs = ["xattr.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go new file mode 100644 index 000000000..cc1e7d764 --- /dev/null +++ b/pkg/sentry/vfs/memxattr/xattr.go @@ -0,0 +1,102 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memxattr provides a default, in-memory extended attribute +// implementation. 
+package memxattr + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" +) + +// SimpleExtendedAttributes implements extended attributes using a map of +// names to values. +// +// +stateify savable +type SimpleExtendedAttributes struct { + // mu protects the below fields. + mu sync.RWMutex `state:"nosave"` + xattrs map[string]string +} + +// Getxattr returns the value at 'name'. +func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) { + x.mu.RLock() + value, ok := x.xattrs[opts.Name] + x.mu.RUnlock() + if !ok { + return "", syserror.ENODATA + } + // Check that the size of the buffer provided in getxattr(2) is large enough + // to contain the value. + if opts.Size != 0 && uint64(len(value)) > opts.Size { + return "", syserror.ERANGE + } + return value, nil +} + +// Setxattr sets 'value' at 'name'. +func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { + x.mu.Lock() + defer x.mu.Unlock() + if x.xattrs == nil { + if opts.Flags&linux.XATTR_REPLACE != 0 { + return syserror.ENODATA + } + x.xattrs = make(map[string]string) + } + + _, ok := x.xattrs[opts.Name] + if ok && opts.Flags&linux.XATTR_CREATE != 0 { + return syserror.EEXIST + } + if !ok && opts.Flags&linux.XATTR_REPLACE != 0 { + return syserror.ENODATA + } + + x.xattrs[opts.Name] = opts.Value + return nil +} + +// Listxattr returns all names in xattrs. +func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { + // Keep track of the size of the buffer needed in listxattr(2) for the list. + listSize := 0 + x.mu.RLock() + names := make([]string, 0, len(x.xattrs)) + for n := range x.xattrs { + names = append(names, n) + // Add one byte per null terminator. + listSize += len(n) + 1 + } + x.mu.RUnlock() + if size != 0 && uint64(listSize) > size { + return nil, syserror.ERANGE + } + return names, nil +} + +// Removexattr removes the xattr at 'name'. +func (x *SimpleExtendedAttributes) Removexattr(name string) error { + x.mu.Lock() + defer x.mu.Unlock() + if _, ok := x.xattrs[name]; !ok { + return syserror.ENODATA + } + delete(x.xattrs, name) + return nil +} diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 1b8ecc415..f06946103 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -233,9 +233,9 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia } vd.dentry.mu.Lock() } - // TODO: Linux requires that either both the mount point and the mount root - // are directories, or neither are, and returns ENOTDIR if this is not the - // case. + // TODO(gvisor.dev/issue/1035): Linux requires that either both the mount + // point and the mount root are directories, or neither are, and returns + // ENOTDIR if this is not the case. mntns := vd.mount.ns mnt := newMount(vfs, fs, root, mntns, opts) vfs.mounts.seq.BeginWrite() @@ -274,9 +274,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti } } - // TODO(jamieliu): Linux special-cases umount of the caller's root, which - // we don't implement yet (we'll just fail it since the caller holds a - // reference on it). + // TODO(gvisor.dev/issue/1035): Linux special-cases umount of the caller's + // root, which we don't implement yet (we'll just fail it since the caller + // holds a reference on it). 
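
The SimpleExtendedAttributes helper added above is what in-memory filesystems can embed for their xattr hooks. The sketch below illustrates its contract: XATTR_CREATE on an existing name fails with EEXIST, an undersized Size hint yields ERANGE, and Size 0 means no limit. It is an illustrative test only, not part of this change.

package memxattr_test

import (
	"testing"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
	"gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
	"gvisor.dev/gvisor/pkg/syserror"
)

func TestSimpleExtendedAttributes(t *testing.T) {
	var x memxattr.SimpleExtendedAttributes

	if err := x.Setxattr(&vfs.SetxattrOptions{Name: "user.tag", Value: "abc", Flags: linux.XATTR_CREATE}); err != nil {
		t.Fatalf("Setxattr failed: %v", err)
	}
	// XATTR_CREATE on an existing name must fail with EEXIST.
	if err := x.Setxattr(&vfs.SetxattrOptions{Name: "user.tag", Value: "xyz", Flags: linux.XATTR_CREATE}); err != syserror.EEXIST {
		t.Errorf("got %v, want EEXIST", err)
	}
	// A Size hint smaller than the value yields ERANGE; Size == 0 means no limit.
	if _, err := x.Getxattr(&vfs.GetxattrOptions{Name: "user.tag", Size: 1}); err != syserror.ERANGE {
		t.Errorf("got %v, want ERANGE", err)
	}
	if v, err := x.Getxattr(&vfs.GetxattrOptions{Name: "user.tag", Size: 0}); err != nil || v != "abc" {
		t.Errorf("got (%q, %v), want (%q, nil)", v, err, "abc")
	}
}
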
vfs.mounts.seq.BeginWrite() if opts.Flags&linux.MNT_DETACH == 0 { diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index 3b933468d..3335e4057 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -55,7 +55,7 @@ func TestMountTableInsertLookup(t *testing.T) { } } -// TODO: concurrent lookup/insertion/removal +// TODO(gvisor.dev/issue/1035): concurrent lookup/insertion/removal. // must be powers of 2 var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8} diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 2f04bf882..534528ce6 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -132,6 +132,20 @@ type SetStatOptions struct { Stat linux.Statx } +// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(), +// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and +// FileDescriptionImpl.Getxattr(). +type GetxattrOptions struct { + // Name is the name of the extended attribute to retrieve. + Name string + + // Size is the maximum value size that the caller will tolerate. If the value + // is larger than size, getxattr methods may return ERANGE, but they are also + // free to ignore the hint entirely (i.e. the value returned may be larger + // than size). All size checking is done independently at the syscall layer. + Size uint64 +} + // SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), // FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and // FileDescriptionImpl.Setxattr(). diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 720b90d8f..cb5bbd781 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -335,7 +335,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia rp := vfs.getResolvingPath(creds, pop) for { err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) - if err != nil { + if err == nil { vfs.putResolvingPath(rp) return nil } @@ -383,14 +383,11 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { // Remove: // - // - O_LARGEFILE, which we always report in FileDescription status flags - // since only 64-bit architectures are supported at this time. - // // - O_CLOEXEC, which affects file descriptors and therefore must be // handled outside of VFS. // // - Unknown flags. - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_LARGEFILE | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. if opts.Flags&linux.O_SYNC != 0 { opts.Flags |= linux.O_DSYNC @@ -680,10 +677,10 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti // ListxattrAt returns all extended attribute names for the file at the given // path. 
-func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) { +func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { rp := vfs.getResolvingPath(creds, pop) for { - names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp) + names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size) if err == nil { vfs.putResolvingPath(rp) return names, nil @@ -705,10 +702,10 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede // GetxattrAt returns the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) { +func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) { rp := vfs.getResolvingPath(creds, pop) for { - val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name) + val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts) if err == nil { vfs.putResolvingPath(rp) return val, nil diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index f7d6009a0..fcc46420f 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -319,8 +319,8 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo // Dump stack only if a new task is detected or if it sometime has // passed since the last time a stack dump was generated. - skipStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod - w.doAction(w.TaskTimeoutAction, skipStack, &buf) + showStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod + w.doAction(w.TaskTimeoutAction, showStack, &buf) } func (w *Watchdog) reportStuckWatchdog() { @@ -329,16 +329,15 @@ func (w *Watchdog) reportStuckWatchdog() { w.doAction(w.TaskTimeoutAction, false, &buf) } -// doAction will take the given action. If the action is LogWarnind and -// skipStack is true, then the stack printing will be skipped. -func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { +// doAction will take the given action. If the action is LogWarning and +// showStack is false, then the stack printing will be skipped. +func (w *Watchdog) doAction(action Action, showStack bool, msg *bytes.Buffer) { switch action { case LogWarning: - if skipStack { + if !showStack { msg.WriteString("\n...[stack dump skipped]...") log.Warningf(msg.String()) return - } log.TracebackAll(msg.String()) w.lastStackDump = time.Now() diff --git a/pkg/state/state.go b/pkg/state/state.go index dbe507ab4..03ae2dbb0 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -241,10 +241,7 @@ func Register(name string, instance interface{}, fns Fns) { // // This function is used by the stateify tool. func IsZeroValue(val interface{}) bool { - if val == nil { - return true - } - return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface()) + return val == nil || reflect.ValueOf(val).Elem().IsZero() } // step captures one encoding / decoding step. 
On each step, there is up to one diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 8dc0f7c0e..c1745ba6a 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -107,6 +107,8 @@ func DstAddr(addr tcpip.Address) NetworkChecker { // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). func TTL(ttl uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + var v uint8 switch ip := h[0].(type) { case header.IPv4: @@ -310,6 +312,8 @@ func SrcPort(port uint16) TransportChecker { // DstPort creates a checker that checks the destination port. func DstPort(port uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if p := h.DestinationPort(); p != port { t.Errorf("Bad destination port, got %v, want %v", p, port) } @@ -336,6 +340,7 @@ func SeqNum(seq uint32) TransportChecker { func AckNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -350,6 +355,8 @@ func AckNum(seq uint32) TransportChecker { // Window creates a checker that checks the tcp window. func Window(window uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -381,6 +388,8 @@ func TCPFlags(flags uint8) TransportChecker { // given mask, match the supplied flags. func TCPFlagsMatch(flags, mask uint8) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -398,6 +407,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { // If wndscale is negative, the window scale option must not be present. func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -494,6 +505,8 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { // skipped. func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -612,6 +625,8 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { // Payload creates a checker that checks the payload. 
func Payload(want []byte) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if got := h.Payload(); !reflect.DeepEqual(got, want) { t.Errorf("Wrong payload, got %v, want %v", got, want) } @@ -644,6 +659,7 @@ func ICMPv4(checkers ...TransportChecker) NetworkChecker { func ICMPv4Type(want header.ICMPv4Type) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv4, ok := h.(header.ICMPv4) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) @@ -658,6 +674,7 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { func ICMPv4Code(want byte) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv4, ok := h.(header.ICMPv4) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) @@ -700,6 +717,7 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { func ICMPv6Type(want header.ICMPv6Type) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv6, ok := h.(header.ICMPv6) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) @@ -714,6 +732,7 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { func ICMPv6Code(want byte) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv6, ok := h.(header.ICMPv6) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) @@ -728,7 +747,7 @@ func ICMPv6Code(want byte) TransportChecker { // message for type of ty, with potentially additional checks specified by // checkers. // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid // NDP message as far as the size of the message (minSize) is concerned. The // values within the message are up to checkers to validate. func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { @@ -760,9 +779,9 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N // Neighbor Solicitation message (as per the raw wire format), with potentially // additional checks specified by checkers. // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the messages concerned. The values within -// the message are up to checkers to validate. +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNS message as far as the size of the message is concerned. The values +// within the message are up to checkers to validate. func NDPNS(checkers ...TransportChecker) NetworkChecker { return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) } @@ -780,7 +799,54 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { ns := header.NDPNeighborSolicit(icmp.NDPPayload()) if got := ns.TargetAddress(); got != want { - t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) + t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) + } + } +} + +// NDPNA creates a checker that checks that the packet contains a valid NDP +// Neighbor Advertisement message (as per the raw wire format), with potentially +// additional checks specified by checkers. +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNA message as far as the size of the message is concerned. 
The values +// within the message are up to checkers to validate. +func NDPNA(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) +} + +// NDPNATargetAddress creates a checker that checks the Target Address field of +// a header.NDPNeighborAdvert. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNATargetAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + + if got := na.TargetAddress(); got != want { + t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) + } + } +} + +// NDPNASolicitedFlag creates a checker that checks the Solicited field of +// a header.NDPNeighborAdvert. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNASolicitedFlag(want bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + + if got := na.SolicitedFlag(); got != want { + t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) } } } @@ -819,6 +885,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) } + case header.NDPTargetLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } default: t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) } @@ -831,6 +904,21 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption } } +// NDPNAOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Neighbor Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNAOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + ndpOptions(t, na.Options(), opts) + } +} + // NDPNSOptions creates a checker that checks that the packet contains the // provided NDP options within an NDP Neighbor Solicitation message. // @@ -849,7 +937,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid // NDPRS as far as the size of the message is concerned. The values within the // message are up to checkers to validate. 
func NDPRS(checkers ...TransportChecker) NetworkChecker { diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 7094f3f0b..0cde694dc 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -21,6 +21,7 @@ go_library( "ndp_options.go", "ndp_router_advert.go", "ndp_router_solicit.go", + "ndpoptionidentifier_string.go", "tcp.go", "udp.go", ], diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 7a0014ad9..14413f2ce 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -88,7 +88,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { - t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr) + t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr) } }) } diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index e6a6ad39b..5d3975c56 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -15,32 +15,47 @@ package header import ( + "bytes" "encoding/binary" "errors" "fmt" + "io" "math" "time" "gvisor.dev/gvisor/pkg/tcpip" ) +// NDPOptionIdentifier is an NDP option type identifier. +type NDPOptionIdentifier uint8 + const ( // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPSourceLinkLayerAddressOptionType = 1 + NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1 // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPTargetLinkLayerAddressOptionType = 2 + NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2 + + // NDPPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + NDPPrefixInformationType NDPOptionIdentifier = 3 + + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25 + // NDPDNSSearchListOptionType is the type of the DNS Search List option, + // as per RFC 8106 section 5.2. + NDPDNSSearchListOptionType = 31 +) + +const ( // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer // Address option for an Ethernet address. NDPLinkLayerAddressSize = 8 - // NDPPrefixInformationType is the type of the Prefix Information - // option, as per RFC 4861 section 4.6.2. - NDPPrefixInformationType = 3 - // ndpPrefixInformationLength is the expected length, in bytes, of the // body of an NDP Prefix Information option, as per RFC 4861 section // 4.6.2 which specifies that the Length field is 4. Given this, the @@ -91,10 +106,6 @@ const ( // within an NDPPrefixInformation. ndpPrefixInformationPrefixOffset = 14 - // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS - // Server option, as per RFC 8106 section 5.1. - NDPRecursiveDNSServerOptionType = 25 - // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte // Lifetime field within an NDPRecursiveDNSServer. ndpRecursiveDNSServerLifetimeOffset = 2 @@ -103,10 +114,31 @@ const ( // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. 
ndpRecursiveDNSServerAddressesOffset = 6 - // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS - // Server option's length field value when it contains at least one - // IPv6 address. - minNDPRecursiveDNSServerLength = 3 + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS Server + // option's body size when it contains at least one IPv6 address, as per + // RFC 8106 section 5.3.1. + minNDPRecursiveDNSServerBodySize = 22 + + // ndpDNSSearchListLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPDNSSearchList. + ndpDNSSearchListLifetimeOffset = 2 + + // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list + // domain names within an NDPDNSSearchList. + ndpDNSSearchListDomainNamesOffset = 6 + + // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's + // body size when it contains at least one domain name, as per RFC 8106 + // section 5.3.1. + minNDPDNSSearchListBodySize = 14 + + // maxDomainNameLabelLength is the maximum length of a domain name + // label, as per RFC 1035 section 3.1. + maxDomainNameLabelLength = 63 + + // maxDomainNameLength is the maximum length of a domain name, including + // label AND label length octet, as per RFC 1035 section 3.1. + maxDomainNameLength = 255 // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of @@ -132,16 +164,13 @@ var ( // few NDPOption then modify the backing NDPOptions so long as the // NDPOptionIterator obtained before modification is no longer used. type NDPOptionIterator struct { - // The NDPOptions this NDPOptionIterator is iterating over. - opts NDPOptions + opts *bytes.Buffer } // Potential errors when iterating over an NDPOptions. var ( - ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") - ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") - ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") - ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC") + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header") ) // Next returns the next element in the backing NDPOptions, or true if we are @@ -152,48 +181,50 @@ var ( func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { for { // Do we still have elements to look at? - if len(i.opts) == 0 { + if i.opts.Len() == 0 { return nil, true, nil } - // Do we have enough bytes for an NDP option that has a Length - // field of at least 1? Note, 0 in the Length field is invalid. - if len(i.opts) < lengthByteUnits { - return nil, true, ErrNDPOptBufExhausted - } - // Get the Type field. - t := i.opts[0] - - // Get the Length field. - l := i.opts[1] + temp, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err)) + } - // This would indicate an erroneous NDP option as the Length - // field should never be 0. - if l == 0 { - return nil, true, ErrNDPOptZeroLength + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once + // we start parsing an option; we expect the buffer to contain enough + // bytes for the whole option. 
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF) } + kind := NDPOptionIdentifier(temp) - // How many bytes are in the option body? - numBytes := int(l) * lengthByteUnits - numBodyBytes := numBytes - 2 - - potentialBody := i.opts[2:] + // Get the Length field. + length, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err)) + } - // This would indicate an erroenous NDPOptions buffer as we ran - // out of the buffer in the middle of an NDP option. - if left := len(potentialBody); left < numBodyBytes { - return nil, true, ErrNDPOptBufExhausted + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF) } - // Get only the options body, leaving the rest of the options - // buffer alone. - body := potentialBody[:numBodyBytes] + // This would indicate an erroneous NDP option as the Length field should + // never be 0. + if length == 0 { + return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader) + } - // Update opts with the remaining options body. - i.opts = i.opts[numBytes:] + // Get the body. + numBytes := int(length) * lengthByteUnits + numBodyBytes := numBytes - 2 + body := i.opts.Next(numBodyBytes) + if len(body) < numBodyBytes { + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF) + } - switch t { + switch kind { case NDPSourceLinkLayerAddressOptionType: return NDPSourceLinkLayerAddressOption(body), false, nil @@ -205,22 +236,23 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { // body is ndpPrefixInformationLength, as per RFC 4861 // section 4.6.2. if numBodyBytes != ndpPrefixInformationLength { - return nil, true, ErrNDPOptMalformedBody + return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody) } return NDPPrefixInformation(body), false, nil case NDPRecursiveDNSServerOptionType: - // RFC 8106 section 5.3.1 outlines that the RDNSS option - // must have a minimum length of 3 so it contains at - // least one IPv6 address. - if l < minNDPRecursiveDNSServerLength { - return nil, true, ErrNDPInvalidLength + opt := NDPRecursiveDNSServer(body) + if err := opt.checkAddresses(); err != nil { + return nil, true, err } - opt := NDPRecursiveDNSServer(body) - if len(opt.Addresses()) == 0 { - return nil, true, ErrNDPOptMalformedBody + return opt, false, nil + + case NDPDNSSearchListOptionType: + opt := NDPDNSSearchList(body) + if err := opt.checkDomainNames(); err != nil { + return nil, true, err } return opt, false, nil @@ -247,10 +279,16 @@ type NDPOptions []byte // // See NDPOptionIterator for more information. 
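
With the parser reworked as above, walking an options buffer is a two-step affair: Iter(true) validates the whole buffer up front, after which Next cannot fail. The sketch below serializes a single Target Link Layer Address option and iterates it back; it is illustrative only and not part of this change.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	// One Target Link Layer Address option: 2 header bytes + 6 address bytes,
	// i.e. a single 8-byte length unit.
	opts := header.NDPOptions(make([]byte, 8))
	opts.Serialize(header.NDPOptionsSerializer{
		header.NDPTargetLinkLayerAddressOption("\x02\x03\x04\x05\x06\x07"),
	})

	// check = true makes Iter validate the whole buffer before iteration.
	it, err := opts.Iter(true)
	if err != nil {
		panic(err)
	}
	for {
		opt, done, err := it.Next()
		if err != nil || done {
			break
		}
		fmt.Printf("type=%d length=%d\n", opt.Type(), opt.Length())
	}
}
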
func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { - it := NDPOptionIterator{opts: b} + it := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } if check { - for it2 := it; true; { + it2 := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } + + for { if _, done, err := it2.Next(); err != nil || done { return it, err } @@ -278,7 +316,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { continue } - b[0] = o.Type() + b[0] = byte(o.Type()) // We know this safe because paddedLength would have returned // 0 if o had an invalid length (> 255 * lengthByteUnits). @@ -304,7 +342,7 @@ type NDPOption interface { fmt.Stringer // Type returns the type of the receiver. - Type() uint8 + Type() NDPOptionIdentifier // Length returns the length of the body of the receiver, in bytes. Length() int @@ -386,7 +424,7 @@ func (b NDPOptionsSerializer) Length() int { type NDPSourceLinkLayerAddressOption tcpip.LinkAddress // Type implements NDPOption.Type. -func (o NDPSourceLinkLayerAddressOption) Type() uint8 { +func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier { return NDPSourceLinkLayerAddressOptionType } @@ -426,7 +464,7 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { type NDPTargetLinkLayerAddressOption tcpip.LinkAddress // Type implements NDPOption.Type. -func (o NDPTargetLinkLayerAddressOption) Type() uint8 { +func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier { return NDPTargetLinkLayerAddressOptionType } @@ -466,7 +504,7 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { type NDPPrefixInformation []byte // Type implements NDPOption.Type. -func (o NDPPrefixInformation) Type() uint8 { +func (o NDPPrefixInformation) Type() NDPOptionIdentifier { return NDPPrefixInformationType } @@ -590,7 +628,7 @@ type NDPRecursiveDNSServer []byte // Type returns the type of an NDP Recursive DNS Server option. // // Type implements NDPOption.Type. -func (NDPRecursiveDNSServer) Type() uint8 { +func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier { return NDPRecursiveDNSServerOptionType } @@ -613,7 +651,12 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { // String implements fmt.Stringer.String. func (o NDPRecursiveDNSServer) String() string { - return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime()) + lt := o.Lifetime() + addrs, err := o.Addresses() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt) } // Lifetime returns the length of time that the DNS server addresses @@ -632,29 +675,225 @@ func (o NDPRecursiveDNSServer) Lifetime() time.Duration { // Addresses returns the recursive DNS server IPv6 addresses that may be // used for name resolution. // -// Note, some of the addresses returned MAY be link-local addresses. +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) { + var addrs []tcpip.Address + return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) }) +} + +// checkAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and returns any error it encounters. +func (o NDPRecursiveDNSServer) checkAddresses() error { + return o.iterAddresses(nil) +} + +// iterAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and calls a function with each valid unicast IPv6 address. 
// -// Addresses may panic if o does not hold valid IPv6 addresses. -func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { - l := len(o) - if l < ndpRecursiveDNSServerAddressesOffset { - return nil +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error { + if l := len(o); l < minNDPRecursiveDNSServerBodySize { + return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF) } - l -= ndpRecursiveDNSServerAddressesOffset + o = o[ndpRecursiveDNSServerAddressesOffset:] + l := len(o) if l%IPv6AddressSize != 0 { - return nil + return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody) } - buf := o[ndpRecursiveDNSServerAddressesOffset:] - var addrs []tcpip.Address - for len(buf) > 0 { - addr := tcpip.Address(buf[:IPv6AddressSize]) + for i := 0; len(o) != 0; i++ { + addr := tcpip.Address(o[:IPv6AddressSize]) if !IsV6UnicastAddress(addr) { - return nil + return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody) + } + + if fn != nil { + fn(addr) } - addrs = append(addrs, addr) - buf = buf[IPv6AddressSize:] + + o = o[IPv6AddressSize:] } - return addrs + + return nil +} + +// NDPDNSSearchList is the NDP DNS Search List option, as defined by +// RFC 8106 section 5.2. +type NDPDNSSearchList []byte + +// Type implements NDPOption.Type. +func (o NDPDNSSearchList) Type() NDPOptionIdentifier { + return NDPDNSSearchListOptionType +} + +// Length implements NDPOption.Length. +func (o NDPDNSSearchList) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPDNSSearchList) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpDNSSearchListLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// String implements fmt.Stringer.String. +func (o NDPDNSSearchList) String() string { + lt := o.Lifetime() + domainNames, err := o.DomainNames() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt) +} + +// Lifetime returns the length of time that the DNS search list of domain names +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the domain names should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +func (o NDPDNSSearchList) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:])) +} + +// DomainNames returns a DNS search list of domain names. +// +// DomainNames will parse the backing buffer as outlined by RFC 1035 section +// 3.1 and return a list of strings, with all domain names in lower case. +func (o NDPDNSSearchList) DomainNames() ([]string, error) { + var domainNames []string + return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) }) +} + +// checkDomainNames iterates over the domain names in an NDP DNS Search List +// option and returns any error it encounters. 
+func (o NDPDNSSearchList) checkDomainNames() error { + return o.iterDomainNames(nil) +} + +// iterDomainNames iterates over the domain names in an NDP DNS Search List +// option and calls a function with each valid domain name. +func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error { + if l := len(o); l < minNDPDNSSearchListBodySize { + return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF) + } + + var searchList bytes.Reader + searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:]) + + var scratch [maxDomainNameLength]byte + domainName := bytes.NewBuffer(scratch[:]) + + // Parse the domain names, as per RFC 1035 section 3.1. + for searchList.Len() != 0 { + domainName.Reset() + + // Parse a label within a domain name, as per RFC 1035 section 3.1. + for { + // The first byte is the label length. + labelLenByte, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected + // once we start parsing a domain name; we expect the buffer to contain + // enough bytes for the whole domain name. + return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF) + } + labelLen := int(labelLenByte) + + // A zero-length label implies the end of a domain name. + if labelLen == 0 { + // If the domain name is empty or we have no callback function, do + // nothing further with the current domain name. + if domainName.Len() == 0 || fn == nil { + break + } + + // Ignore the trailing period in the parsed domain name. + domainName.Truncate(domainName.Len() - 1) + fn(domainName.String()) + break + } + + // The label's length must not exceed the maximum length for a label. + if labelLen > maxDomainNameLabelLength { + return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody) + } + + // The label (and trailing period) must not make the domain name too long. + if labelLen+1 > domainName.Cap()-domainName.Len() { + return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody) + } + + // Copy the label and add a trailing period. 
+ for i := 0; i < labelLen; i++ { + b, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err)) + } + + return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF) + } + + // As per RFC 1035 section 2.3.1: + // 1) the label must only contain ASCII include letters, digits and + // hyphens + // 2) the first character in a label must be a letter + // 3) the last letter in a label must be a letter or digit + + if !isLetter(b) { + if i == 0 { + return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + + if b == '-' { + if i == labelLen-1 { + return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody) + } + } else if !isDigit(b) { + return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + } + + // If b is an upper case character, make it lower case. + if isUpperLetter(b) { + b = b - 'A' + 'a' + } + + if err := domainName.WriteByte(b); err != nil { + panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err)) + } + } + if err := domainName.WriteByte('.'); err != nil { + panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err)) + } + } + } + + return nil +} + +func isLetter(b byte) bool { + return b >= 'a' && b <= 'z' || isUpperLetter(b) +} + +func isUpperLetter(b byte) bool { + return b >= 'A' && b <= 'Z' +} + +func isDigit(b byte) bool { + return b >= '0' && b <= '9' } diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 1cb9f5dc8..dc4591253 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -16,6 +16,10 @@ package header import ( "bytes" + "errors" + "fmt" + "io" + "regexp" "testing" "time" @@ -115,7 +119,7 @@ func TestNDPNeighborAdvert(t *testing.T) { // Make sure flags got updated in the backing buffer. if got := b[ndpNAFlagsOffset]; got != 64 { - t.Errorf("got flags byte = %d, want = 64") + t.Errorf("got flags byte = %d, want = 64", got) } } @@ -543,8 +547,12 @@ func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { want := []tcpip.Address{ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", } - if got := opt.Addresses(); !cmp.Equal(got, want) { - t.Errorf("got Addresses = %v, want = %v", got, want) + addrs, err := opt.Addresses() + if err != nil { + t.Errorf("opt.Addresses() = %s", err) + } + if diff := cmp.Diff(addrs, want); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) } // Iterator should not return anything else. @@ -638,8 +646,12 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { if got := opt.Lifetime(); got != test.lifetime { t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) } - if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { - t.Errorf("got Addresses = %v, want = %v", got, test.addrs) + addrs, err := opt.Addresses() + if err != nil { + t.Errorf("opt.Addresses() = %s", err) + } + if diff := cmp.Diff(addrs, test.addrs); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) } // Iterator should not return anything else. 
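[Editor's note] Both the parsers above and the tests in this file now wrap sentinel errors (io.ErrUnexpectedEOF, ErrNDPOptMalformedBody) with %w and match them with errors.Is instead of comparing errors directly. A small self-contained illustration of why that matters, using generic names rather than the gVisor sentinels:

package main

import (
	"errors"
	"fmt"
	"io"
)

// readBody mimics the parsers' style: annotate the failure with context while
// wrapping the sentinel via %w so callers can still match on it.
func readBody(got, want int) error {
	if got < want {
		return fmt.Errorf("got %d bytes, expected at least %d bytes: %w", got, want, io.ErrUnexpectedEOF)
	}
	return nil
}

func main() {
	err := readBody(3, 8)
	fmt.Println(err == io.ErrUnexpectedEOF)          // false: direct comparison misses the wrapped sentinel
	fmt.Println(errors.Is(err, io.ErrUnexpectedEOF)) // true: errors.Is unwraps the chain
}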
@@ -657,42 +669,513 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { } } +// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList. +func TestNDPDNSSearchListOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + domainNames []string + err error + }{ + { + name: "Valid1Label", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: []string{ + "abc", + }, + err: nil, + }, + { + name: "Valid2Label", + buf: []byte{ + 0, 0, + 0, 0, 0, 5, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 0, + 0, 0, 0, 0, 0, 0, + }, + lifetime: 5 * time.Second, + domainNames: []string{ + "abc.abcd", + }, + err: nil, + }, + { + name: "Valid3Label", + buf: []byte{ + 0, 0, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + }, + lifetime: 16777216 * time.Second, + domainNames: []string{ + "abc.abcd.e", + }, + err: nil, + }, + { + name: "Valid2Domains", + buf: []byte{ + 0, 0, + 1, 2, 3, 4, + 3, 'a', 'b', 'c', + 0, + 2, 'd', 'e', + 3, 'x', 'y', 'z', + 0, + 0, 0, 0, + }, + lifetime: 16909060 * time.Second, + domainNames: []string{ + "abc", + "de.xyz", + }, + err: nil, + }, + { + name: "Valid3DomainsMixedCase", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 'B', 'c', + 0, + 2, 'd', 'E', + 3, 'X', 'y', 'z', + 0, + 1, 'J', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abc", + "de.xyz", + "j", + }, + err: nil, + }, + { + name: "ValidDomainAfterNULL", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 'B', 'c', + 0, 0, 0, 0, + 2, 'd', 'E', + 3, 'X', 'y', 'z', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abc", + "de.xyz", + }, + err: nil, + }, + { + name: "Valid0Domains", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 0, + 0, 0, 0, 0, 0, 0, 0, + }, + lifetime: 0, + domainNames: nil, + err: nil, + }, + { + name: "NoTrailingNull", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g', + }, + lifetime: 0, + domainNames: nil, + err: io.ErrUnexpectedEOF, + }, + { + name: "IncorrectLength", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g', + }, + lifetime: 0, + domainNames: nil, + err: io.ErrUnexpectedEOF, + }, + { + name: "IncorrectLengthWithNULL", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 7, 'a', 'b', 'c', 'd', 'e', 'f', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "LabelOfLength63", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk", + }, + err: nil, + }, + { + name: "LabelOfLength64", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "DomainNameOfLength255", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 
'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij", + }, + err: nil, + }, + { + name: "DomainNameOfLength256", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "StartingDigitForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, '9', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "StartingHyphenForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, '-', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "EndingHyphenForLabel", 
+ buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', '-', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "EndingDigitForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', '9', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: []string{ + "ab9", + }, + err: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opt := NDPDNSSearchList(test.buf) + + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + domainNames, err := opt.DomainNames() + if !errors.Is(err, test.err) { + t.Errorf("opt.DomainNames() = %s", err) + } + if diff := cmp.Diff(domainNames, test.domainNames); diff != "" { + t.Errorf("mismatched domain names (-want +got):\n%s", diff) + } + }) + } +} + +func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { + for r := rune(0); r <= 255; r++ { + t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) { + buf := []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 0 /* will be replaced */, 'c', + 0, + 0, 0, 0, + } + buf[8] = uint8(r) + opt := NDPDNSSearchList(buf) + + // As per RFC 1035 section 2.3.1, the label must only include ASCII + // letters, digits and hyphens (a-z, A-Z, 0-9, -). + var expectedErr error + re := regexp.MustCompile(`[a-zA-Z0-9-]`) + if !re.Match([]byte{byte(r)}) { + expectedErr = ErrNDPOptMalformedBody + } + + if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) { + t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody) + } + }) + } +} + +func TestNDPDNSSearchListOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 31, 3, 0, 0, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPDNSSearchList(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPDNSSearchListOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType) + } + + opt, ok := next.(NDPDNSSearchList) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next) + } + if got := opt.Type(); got != 31 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16777216*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + domainNames, err := opt.DomainNames() + if err != nil { + t.Errorf("opt.DomainNames() = %s", err) + } + if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" { + t.Errorf("domain names mismatch (-want +got):\n%s", diff) + } + + // Iterator should not return anything else. 
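[Editor's note] The StartingDigitForLabel, EndingHyphenForLabel, and LabelOfLength63/64 cases above exercise the RFC 1035 section 2.3.1 label rules that iterDomainNames enforces. A standalone sketch of those per-label rules, assuming the usual 63-byte label limit; this is an illustration only, not the parser itself:

package main

import (
	"errors"
	"fmt"
)

var errBadLabel = errors.New("malformed domain name label")

// validLabel applies the per-label rules: 1-63 bytes, only letters, digits
// and hyphens, the first byte a letter, and the last byte not a hyphen.
func validLabel(label string) error {
	if n := len(label); n == 0 || n > 63 {
		return fmt.Errorf("label length %d outside [1, 63]: %w", n, errBadLabel)
	}
	for i := 0; i < len(label); i++ {
		b := label[i]
		switch {
		case b >= 'a' && b <= 'z' || b >= 'A' && b <= 'Z':
			// Letters are allowed anywhere in a label.
		case b >= '0' && b <= '9' || b == '-':
			if i == 0 {
				return fmt.Errorf("label must start with a letter: %w", errBadLabel)
			}
			if b == '-' && i == len(label)-1 {
				return fmt.Errorf("label must not end with a hyphen: %w", errBadLabel)
			}
		default:
			return fmt.Errorf("invalid character %q in label: %w", b, errBadLabel)
		}
	}
	return nil
}

func main() {
	fmt.Println(validLabel("ab9"), validLabel("9bc"), validLabel("ab-"))
	// <nil>, then the start-with-letter and end-with-hyphen errors
}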
+ next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { tests := []struct { - name string - buf []byte - expected error + name string + buf []byte + expectedErr error }{ { - "ZeroLengthField", - []byte{0, 0, 0, 0, 0, 0, 0, 0}, - ErrNDPOptZeroLength, + name: "ZeroLengthField", + buf: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + expectedErr: ErrNDPOptMalformedHeader, }, { - "ValidSourceLinkLayerAddressOption", - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - nil, + name: "ValidSourceLinkLayerAddressOption", + buf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, + expectedErr: nil, }, { - "TooSmallSourceLinkLayerAddressOption", - []byte{1, 1, 1, 2, 3, 4, 5}, - ErrNDPOptBufExhausted, + name: "TooSmallSourceLinkLayerAddressOption", + buf: []byte{1, 1, 1, 2, 3, 4, 5}, + expectedErr: io.ErrUnexpectedEOF, }, { - "ValidTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - nil, + name: "ValidTargetLinkLayerAddressOption", + buf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, + expectedErr: nil, }, { - "TooSmallTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5}, - ErrNDPOptBufExhausted, + name: "TooSmallTargetLinkLayerAddressOption", + buf: []byte{2, 1, 1, 2, 3, 4, 5}, + expectedErr: io.ErrUnexpectedEOF, }, { - "ValidPrefixInformation", - []byte{ + name: "ValidPrefixInformation", + buf: []byte{ 3, 4, 43, 64, 1, 2, 3, 4, 5, 6, 7, 8, @@ -702,11 +1185,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, 24, }, - nil, + expectedErr: nil, }, { - "TooSmallPrefixInformation", - []byte{ + name: "TooSmallPrefixInformation", + buf: []byte{ 3, 4, 43, 64, 1, 2, 3, 4, 5, 6, 7, 8, @@ -716,11 +1199,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, }, - ErrNDPOptBufExhausted, + expectedErr: io.ErrUnexpectedEOF, }, { - "InvalidPrefixInformationLength", - []byte{ + name: "InvalidPrefixInformationLength", + buf: []byte{ 3, 3, 43, 64, 1, 2, 3, 4, 5, 6, 7, 8, @@ -728,11 +1211,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 9, 10, 11, 12, 13, 14, 15, 16, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, }, { - "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", - []byte{ + name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", + buf: []byte{ // Source Link-Layer Address. 1, 1, 1, 2, 3, 4, 5, 6, @@ -749,11 +1232,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, 24, }, - nil, + expectedErr: nil, }, { - "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", - []byte{ + name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + buf: []byte{ // Source Link-Layer Address. 
1, 1, 1, 2, 3, 4, 5, 6, @@ -775,52 +1258,153 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, 24, }, - nil, + expectedErr: nil, }, { - "InvalidRecursiveDNSServerCutsOffAddress", - []byte{ + name: "InvalidRecursiveDNSServerCutsOffAddress", + buf: []byte{ 25, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 4, 5, 6, 7, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, }, { - "InvalidRecursiveDNSServerInvalidLengthField", - []byte{ + name: "InvalidRecursiveDNSServerInvalidLengthField", + buf: []byte{ 25, 2, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, }, - ErrNDPInvalidLength, + expectedErr: io.ErrUnexpectedEOF, }, { - "RecursiveDNSServerTooSmall", - []byte{ + name: "RecursiveDNSServerTooSmall", + buf: []byte{ 25, 1, 0, 0, 0, 0, 0, }, - ErrNDPOptBufExhausted, + expectedErr: io.ErrUnexpectedEOF, }, { - "RecursiveDNSServerMulticast", - []byte{ + name: "RecursiveDNSServerMulticast", + buf: []byte{ 25, 3, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, }, { - "RecursiveDNSServerUnspecified", - []byte{ + name: "RecursiveDNSServerUnspecified", + buf: []byte{ 25, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "DNSSearchListLargeCompliantRFC1035", + buf: []byte{ + 31, 33, 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', + 0, + }, + expectedErr: nil, + }, + { + name: "DNSSearchListNonCompliantRFC1035", + buf: []byte{ + 31, 33, 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 
'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "DNSSearchListValidSmall", + buf: []byte{ + 31, 2, 0, 0, + 0, 0, 0, 0, + 6, 'a', 'b', 'c', 'd', 'e', 'f', + 0, + }, + expectedErr: nil, + }, + { + name: "DNSSearchListTooSmall", + buf: []byte{ + 31, 1, 0, 0, + 0, 0, 0, + }, + expectedErr: io.ErrUnexpectedEOF, }, } @@ -828,8 +1412,8 @@ func TestNDPOptionsIterCheck(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := NDPOptions(test.buf) - if _, err := opts.Iter(true); err != test.expected { - t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) + if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) { + t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr) } // test.buf may be malformed but we chose not to check diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go new file mode 100644 index 000000000..6fe9a336b --- /dev/null +++ b/pkg/tcpip/header/ndpoptionidentifier_string.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT. + +package header + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[NDPSourceLinkLayerAddressOptionType-1] + _ = x[NDPTargetLinkLayerAddressOptionType-2] + _ = x[NDPPrefixInformationType-3] + _ = x[NDPRecursiveDNSServerOptionType-25] +} + +const ( + _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType" + _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType" +) + +var ( + _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} +) + +func (i NDPOptionIdentifier) String() string { + switch { + case 1 <= i && i <= 3: + i -= 1 + return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]] + case i == 25: + return _NDPOptionIdentifier_name_1 + default: + return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index b4a0ae53d..9bf67686d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -50,13 +50,11 @@ type NotificationHandle struct { } type queue struct { + // c is the outbound packet channel. + c chan PacketInfo // mu protects fields below. - mu sync.RWMutex - // c is the outbound packet channel. Sending to c should hold mu. - c chan PacketInfo - numWrite int - numRead int - notify []*NotificationHandle + mu sync.RWMutex + notify []*NotificationHandle } func (q *queue) Close() { @@ -64,11 +62,8 @@ func (q *queue) Close() { } func (q *queue) Read() (PacketInfo, bool) { - q.mu.Lock() - defer q.mu.Unlock() select { case p := <-q.c: - q.numRead++ return p, true default: return PacketInfo{}, false @@ -76,15 +71,8 @@ func (q *queue) Read() (PacketInfo, bool) { } func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { - // We have to receive from channel without holding the lock, since it can - // block indefinitely. This will cause a window that numWrite - numRead - // produces a larger number, but won't go to negative. numWrite >= numRead - // still holds. select { case pkt := <-q.c: - q.mu.Lock() - defer q.mu.Unlock() - q.numRead++ return pkt, true case <-ctx.Done(): return PacketInfo{}, false @@ -93,16 +81,12 @@ func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { func (q *queue) Write(p PacketInfo) bool { wrote := false - - // It's important to make sure nobody can see numWrite until we increment it, - // so numWrite >= numRead holds. 
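[Editor's note] The queue rework in channel.go drops the hand-maintained numWrite/numRead counters: a select with a default arm already gives a non-blocking send on the buffered channel, and len(c) reports how many packets are currently queued. A minimal standalone sketch of that pattern, using plain ints in place of PacketInfo:

package main

import "fmt"

// tryWrite attempts a non-blocking send: if the buffered channel is full the
// default arm runs and the value is dropped instead of blocking the caller.
func tryWrite(c chan int, v int) bool {
	select {
	case c <- v:
		return true
	default:
		return false
	}
}

func main() {
	c := make(chan int, 2)
	fmt.Println(tryWrite(c, 1), tryWrite(c, 2), tryWrite(c, 3)) // true true false
	fmt.Println(len(c))                                         // 2 items queued, no separate counters needed
}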
- q.mu.Lock() select { case q.c <- p: wrote = true - q.numWrite++ default: } + q.mu.Lock() notify := q.notify q.mu.Unlock() @@ -116,13 +100,7 @@ func (q *queue) Write(p PacketInfo) bool { } func (q *queue) Num() int { - q.mu.RLock() - defer q.mu.RUnlock() - n := q.numWrite - q.numRead - if n < 0 { - panic("numWrite < numRead") - } - return n + return len(q.c) } func (q *queue) AddNotify(notify Notification) *NotificationHandle { diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 7198742b7..b857ce9d0 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -91,7 +91,7 @@ func (p PacketDispatchMode) String() string { case PacketMMap: return "PacketMMap" default: - return fmt.Sprintf("unknown packet dispatch mode %v", p) + return fmt.Sprintf("unknown packet dispatch mode '%d'", p) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 062388f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -21,11 +21,9 @@ package sniffer import ( - "bytes" "encoding/binary" "fmt" "io" - "os" "sync/atomic" "time" @@ -42,12 +40,12 @@ import ( // LogPackets must be accessed atomically. var LogPackets uint32 = 1 -// LogPacketsToFile is a flag used to enable or disable logging packets to a -// pcap file. Valid values are 0 or 1. A file must have been specified when the +// LogPacketsToPCAP is a flag used to enable or disable logging packets to a +// pcap writer. Valid values are 0 or 1. A writer must have been specified when the // sniffer was created for this flag to have effect. // -// LogPacketsToFile must be accessed atomically. -var LogPacketsToFile uint32 = 1 +// LogPacketsToPCAP must be accessed atomically. +var LogPacketsToPCAP uint32 = 1 var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{ header.ICMPv4ProtocolNumber: header.IPv4MinimumSize, @@ -59,7 +57,7 @@ var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip. type endpoint struct { dispatcher stack.NetworkDispatcher lower stack.LinkEndpoint - file *os.File + writer io.Writer maxPCAPLen uint32 } @@ -99,23 +97,22 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { }) } -// NewWithFile creates a new sniffer link-layer endpoint. It wraps around -// another endpoint and logs packets and they traverse the endpoint. +// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around +// another endpoint and logs packets as they traverse the endpoint. // -// Packets can be logged to file in the pcap format. A sniffer created -// with this function will not emit packets using the standard log -// package. +// Packets are logged to writer in the pcap format. A sniffer created with this +// function will not emit packets using the standard log package. // // snapLen is the maximum amount of a packet to be saved. Packets with a length -// less than or equal too snapLen will be saved in their entirety. Longer +// less than or equal to snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. 
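[Editor's note] With the sniffer constructor now accepting any io.Writer rather than an *os.File, a pcap trace can be captured into, for example, an in-memory buffer during tests. A hedged usage sketch based on the NewWithWriter signature shown in this diff; lower stands in for whatever stack.LinkEndpoint is being wrapped, and the 1500-byte snapLen is an arbitrary choice:

package trace

import (
	"bytes"

	"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// wrapWithPCAP wraps lower with a sniffer that writes a pcap stream into an
// in-memory buffer, which the caller can later inspect or flush to disk.
func wrapWithPCAP(lower stack.LinkEndpoint) (stack.LinkEndpoint, *bytes.Buffer, error) {
	var pcap bytes.Buffer
	ep, err := sniffer.NewWithWriter(lower, &pcap, 1500 /* snapLen */)
	if err != nil {
		return nil, nil, err
	}
	return ep, &pcap, nil
}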
-func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { - if err := writePCAPHeader(file, snapLen); err != nil { +func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (stack.LinkEndpoint, error) { + if err := writePCAPHeader(writer, snapLen); err != nil { return nil, err } return &endpoint{ lower: lower, - file: file, + writer: writer, maxPCAPLen: snapLen, }, nil } @@ -124,36 +121,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, pkt.Data.First(), nil) - } - if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := pkt.Data.Views() - length := pkt.Data.Size() - if length > int(e.maxPCAPLen) { - length = int(e.maxPCAPLen) - } - - buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil { - panic(err) - } - for _, v := range vs { - if length == 0 { - break - } - if len(v) > length { - v = v[:length] - } - if _, err := buf.Write([]byte(v)); err != nil { - panic(err) - } - length -= len(v) - } - if _, err := e.file.Write(buf.Bytes()); err != nil { - panic(err) - } - } + e.dumpPacket("recv", nil, protocol, &pkt) e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) } @@ -200,31 +168,43 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, pkt.Header.View(), gso) +func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + writer := e.writer + if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } - if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - hdrBuf := pkt.Header.View() - length := len(hdrBuf) + pkt.Data.Size() - if length > int(e.maxPCAPLen) { - length = int(e.maxPCAPLen) + if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { + totalLength := pkt.Header.UsedLength() + pkt.Data.Size() + length := totalLength + if max := int(e.maxPCAPLen); length > max { + length = max } - - buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil { + if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil { panic(err) } - if len(hdrBuf) > length { - hdrBuf = hdrBuf[:length] - } - if _, err := buf.Write(hdrBuf); err != nil { - panic(err) + write := func(b []byte) { + if len(b) > length { + b = b[:length] + } + for len(b) != 0 { + n, err := writer.Write(b) + if err != nil { + panic(err) + } + b = b[n:] + length -= n + } } - length -= len(hdrBuf) - logVectorisedView(pkt.Data, length, buf) - if _, err := e.file.Write(buf.Bytes()); 
err != nil { - panic(err) + write(pkt.Header.View()) + for _, view := range pkt.Data.Views() { + if length == 0 { + break + } + write(view) } } } @@ -233,7 +213,7 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket(gso, protocol, &pkt) + e.dumpPacket("send", gso, protocol, &pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } @@ -242,55 +222,21 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // forwards the request to the lower endpoint. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.dumpPacket(gso, protocol, pkt) + e.dumpPacket("send", gso, protocol, pkt) } return e.lower.WritePackets(r, gso, pkts, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) - } - if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - length := vv.Size() - if length > int(e.maxPCAPLen) { - length = int(e.maxPCAPLen) - } - - buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { - panic(err) - } - logVectorisedView(vv, length, buf) - if _, err := e.file.Write(buf.Bytes()); err != nil { - panic(err) - } - } + e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ + Data: vv, + }) return e.lower.WriteRawPacket(vv) } -func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { - if length <= 0 { - return - } - for _, v := range vv.Views() { - if len(v) > length { - v = v[:length] - } - n, err := buf.Write(v) - if err != nil { - panic(err) - } - length -= n - if length == 0 { - return - } - } -} - // Wait implements stack.LinkEndpoint.Wait. -func (*endpoint) Wait() {} +func (e *endpoint) Wait() { e.lower.Wait() } func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index b3e239ac7..1646d9cde 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -138,7 +138,8 @@ func TestDirectRequest(t *testing.T) { // Sleep tests are gross, but this will only potentially flake // if there's a bug. If there is no bug this will reliably // succeed. 
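[Editor's note] The arp_test change just below replaces a discarded CancelFunc with a deferred cancel: writing ctx, _ := context.WithTimeout(...) keeps the timer and its resources alive until the deadline fires, and is what go vet's lostcancel check flags. A standalone sketch of the corrected pattern:

package main

import (
	"context"
	"fmt"
	"time"
)

// waitBriefly blocks until the context's deadline passes. The deferred cancel
// releases the timer and the context's resources as soon as the function
// returns, including on early-return paths.
func waitBriefly() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	<-ctx.Done()
	fmt.Println(ctx.Err()) // context deadline exceeded
}

func main() {
	waitBriefly()
}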
- ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a93a7621a..3f71fc520 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -31,6 +31,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index f91180aa3..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -138,53 +138,48 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P targetAddr := ns.TargetAddress() s := r.Stack() - rxNICID := r.NICID() - if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil { - // We will only get an error if rxNICID is unrecognized, - // which should not happen. For now short-circuit this - // packet. + if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { + // We will only get an error if the NIC is unrecognized, which should not + // happen. For now, drop this packet. // // TODO(b/141002840): Handle this better? return } else if isTentative { - // If the target address is tentative and the source - // of the packet is a unicast (specified) address, then - // the source of the packet is attempting to perform - // address resolution on the target. In this case, the - // solicitation is silently ignored, as per RFC 4862 - // section 5.4.3. + // If the target address is tentative and the source of the packet is a + // unicast (specified) address, then the source of the packet is + // attempting to perform address resolution on the target. In this case, + // the solicitation is silently ignored, as per RFC 4862 section 5.4.3. // - // If the target address is tentative and the source of - // the packet is the unspecified address (::), then we - // know another node is also performing DAD for the - // same address (since targetAddr is tentative for us, - // we know we are also performing DAD on it). In this - // case we let the stack know so it can handle such a - // scenario and do nothing further with the NDP NS. - if iph.SourceAddress() == header.IPv6Any { - s.DupTentativeAddrDetected(rxNICID, targetAddr) + // If the target address is tentative and the source of the packet is the + // unspecified address (::), then we know another node is also performing + // DAD for the same address (since the target address is tentative for us, + // we know we are also performing DAD on it). In this case we let the + // stack know so it can handle such a scenario and do nothing further with + // the NS. + if r.RemoteAddress == header.IPv6Any { + s.DupTentativeAddrDetected(e.nicID, targetAddr) } - // Do not handle neighbor solicitations targeted - // to an address that is tentative on the received - // NIC any further. + // Do not handle neighbor solicitations targeted to an address that is + // tentative on the NIC any further. return } - // At this point we know that targetAddr is not tentative on - // rxNICID so the packet is processed as defined in RFC 4861, - // as per RFC 4862 section 5.4.3. 
+ // At this point we know that the target address is not tentative on the NIC + // so the packet is processed as defined in RFC 4861, as per RFC 4862 + // section 5.4.3. + // Is the NS targetting us? if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { - // We don't have a useful answer; the best we can do is ignore the request. return } - // If the NS message has the source link layer option, update the link - // address cache with the link address for the sender of the message. + // If the NS message contains the Source Link-Layer Address option, update + // the link address cache with the value of the option. // // TODO(b/148429853): Properly process the NS message and do Neighbor // Unreachability Detection. + var sourceLinkAddr tcpip.LinkAddress for { opt, done, err := it.Next() if err != nil { @@ -197,22 +192,36 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch opt := opt.(type) { case header.NDPSourceLinkLayerAddressOption: - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress()) + // No RFCs define what to do when an NS message has multiple Source + // Link-Layer Address options. Since no interface can have multiple + // link-layer addresses, we consider such messages invalid. + if len(sourceLinkAddr) != 0 { + received.Invalid.Increment() + return + } + + sourceLinkAddr = opt.EthernetAddress() } } - optsSerializer := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), + unspecifiedSource := r.RemoteAddress == header.IPv6Any + + // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST + // NOT be included when the source IP address is the unspecified address. + // Otherwise, on link layers that have addresses this option MUST be + // included in multicast solicitations and SHOULD be included in unicast + // solicitations. + if len(sourceLinkAddr) == 0 { + if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + received.Invalid.Increment() + return + } + } else if unspecifiedSource { + received.Invalid.Increment() + return + } else { + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - packet.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(packet.NDPPayload()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(targetAddr) - opts := na.Options() - opts.Serialize(optsSerializer) // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the @@ -225,6 +234,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P r := r.Clone() defer r.Release() r.LocalAddress = targetAddr + + // As per RFC 4861 section 7.2.4, if the the source of the solicitation is + // the unspecified address, the node MUST set the Solicited flag to zero and + // multicast the advertisement to the all-nodes address. + solicited := true + if unspecifiedSource { + solicited = false + r.RemoteAddress = header.IPv6AllNodesMulticastAddress + } + + // If the NS has a source link-layer option, use the link address it + // specifies as the remote link address for the response instead of the + // source link address of the packet. 
+ // + // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link + // address cache for the right destination link address instead of manually + // patching the route with the remote link address if one is specified in a + // Source Link-Layer Address option. + if len(sourceLinkAddr) != 0 { + r.RemoteLinkAddress = sourceLinkAddr + } + + optsSerializer := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) + na.SetSolicitedFlag(solicited) + na.SetOverrideFlag(true) + na.SetTargetAddress(targetAddr) + opts := na.Options() + opts.Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) @@ -258,40 +301,38 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P targetAddr := na.TargetAddress() stack := r.Stack() - rxNICID := r.NICID() - if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil { - // We will only get an error if rxNICID is unrecognized, - // which should not happen. For now short-circuit this - // packet. + if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil { + // We will only get an error if the NIC is unrecognized, which should not + // happen. For now short-circuit this packet. // // TODO(b/141002840): Handle this better? return } else if isTentative { - // We just got an NA from a node that owns an address we - // are performing DAD on, implying the address is not - // unique. In this case we let the stack know so it can - // handle such a scenario and do nothing furthur with + // We just got an NA from a node that owns an address we are performing + // DAD on, implying the address is not unique. In this case we let the + // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - stack.DupTentativeAddrDetected(rxNICID, targetAddr) + stack.DupTentativeAddrDetected(e.nicID, targetAddr) return } - // At this point we know that the targetAddress is not tentative - // on rxNICID. However, targetAddr may still be assigned to - // rxNICID but not tentative (it could be permanent). Such a - // scenario is beyond the scope of RFC 4862. As such, we simply - // ignore such a scenario for now and proceed as normal. + // At this point we know that the target address is not tentative on the + // NIC. However, the target address may still be assigned to the NIC but not + // tentative (it could be permanent). Such a scenario is beyond the scope of + // RFC 4862. As such, we simply ignore such a scenario for now and proceed + // as normal. // + // TODO(b/143147598): Handle the scenario described above. Also inform the + // netstack integration that a duplicate address was detected outside of + // DAD. + // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. // - // TODO(b/143147598): Handle the scenario described above. Also - // inform the netstack integration that a duplicate address was - // detected outside of DAD. 
- // // TODO(b/148429853): Properly process the NA message and do Neighbor // Unreachability Detection. + var targetLinkAddr tcpip.LinkAddress for { opt, done, err := it.Next() if err != nil { @@ -304,10 +345,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch opt := opt.(type) { case header.NDPTargetLinkLayerAddressOption: - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress()) + // No RFCs define what to do when an NA message has multiple Target + // Link-Layer Address options. Since no interface can have multiple + // link-layer addresses, we consider such messages invalid. + if len(targetLinkAddr) != 0 { + received.Invalid.Increment() + return + } + + targetLinkAddr = opt.EthernetAddress() } } + if len(targetLinkAddr) != 0 { + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + } + case header.ICMPv6EchoRequest: received.EchoRequest.Increment() if len(v) < header.ICMPv6EchoMinimumSize { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bae09ed94..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -32,7 +32,8 @@ import ( const ( linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) var ( diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 95e5dbf8e..841a0cb7a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -34,6 +34,7 @@ const ( // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" + addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" // Tests use the extension header identifier values as uint8 instead of // header.IPv6ExtensionHeaderIdentifier. @@ -167,6 +168,8 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // packets destined to the IPv6 solicited-node address of an assigned IPv6 // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + const nicID = 1 + tests := []struct { name string protocolFactory stack.TransportProtocol @@ -184,50 +187,61 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as we haven't added - // those addresses. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as we haven't added those addresses. 
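[Editor's note] The TestReceiveOnSolicitedNodeAddr changes here rely on two unicast addresses mapping to the same solicited-node multicast group, since that group is derived from only the low 24 bits of the unicast address (RFC 4291 section 2.7.1): ff02::1:ff00:0/104 plus those three bytes. gVisor exposes this as header.SolicitedNodeAddr; the following is a minimal standalone sketch of the derivation for reference, with made-up addresses rather than the test's own:

package main

import "fmt"

// solicitedNodeAddr derives the solicited-node multicast group for a unicast
// IPv6 address: ff02::1:ff00:0/104 with the low 24 bits of the unicast
// address appended.
func solicitedNodeAddr(addr [16]byte) [16]byte {
	snmc := [16]byte{0: 0xff, 1: 0x02, 11: 0x01, 12: 0xff}
	copy(snmc[13:], addr[13:]) // only the last three bytes of addr matter
	return snmc
}

func main() {
	a := [16]byte{0x0a, 12: 0x01, 15: 0x02}
	b := [16]byte{0x0a, 11: 0x01, 12: 0x01, 15: 0x02} // differs only above the low 24 bits
	fmt.Println(solicitedNodeAddr(a) == solicitedNodeAddr(b)) // true: same multicast group
}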
test.rxf(t, s, e, addr1, snmc, 0) - if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } - // Should receive a packet destined to the solicited - // node address of addr2/addr3 now that we have added - // added addr2. + // Should receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added added addr2. test.rxf(t, s, e, addr1, snmc, 1) - if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) } - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have added addr3. + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added addr3. test.rxf(t, s, e, addr1, snmc, 2) - if err := s.RemoveAddress(1, addr2); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + if err := s.RemoveAddress(nicID, addr2); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err) } - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have removed addr2. + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have removed addr2. test.rxf(t, s, e, addr1, snmc, 3) - if err := s.RemoveAddress(1, addr3); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + // Make sure addr3's endpoint does not get removed from the NIC by + // incrementing its reference count with a route. + r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err) + } + defer r.Release() + + if err := s.RemoveAddress(nicID, addr3); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err) } - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as both of them got - // removed. + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as both of them got removed, even though a route using + // addr3 exists. 
test.rxf(t, s, e, addr1, snmc, 3) }) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index b113aaacc..12b70f7e9 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -173,6 +174,257 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { } } +func TestNeighorSolicitationResponse(t *testing.T) { + const nicID = 1 + nicAddr := lladdr0 + remoteAddr := lladdr1 + nicAddrSNMC := header.SolicitedNodeAddr(nicAddr) + nicLinkAddr := linkAddr0 + remoteLinkAddr0 := linkAddr1 + remoteLinkAddr1 := linkAddr2 + + tests := []struct { + name string + nsOpts header.NDPOptionsSerializer + nsSrcLinkAddr tcpip.LinkAddress + nsSrc tcpip.Address + nsDst tcpip.Address + nsInvalid bool + naDstLinkAddr tcpip.LinkAddress + naSolicited bool + naSrc tcpip.Address + naDst tcpip.Address + }{ + { + name: "Unspecified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Unspecified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: true, + }, + + { + name: "Specified source with 1 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Specified source with 2 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, 
+ nsInvalid: true, + }, + + { + name: "Specified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 2 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(1, 1280, nicLinkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(nicAddr) + opts := ns.Options() + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) + } + + // If we expected the NS to be invalid, we have nothing else to check. 
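+				// The cases marked nsInvalid above follow RFC 4861: a solicitation
+				// from the unspecified address must not carry a Source Link-Layer
+				// Address option (section 4.3), a multicast solicitation from a
+				// specified source must include that option, and the stack
+				// additionally rejects a solicitation carrying the option twice.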
+ return + } + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + p, got := e.Read() + if !got { + t.Fatal("expected an NDP NA response") + } + + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } + + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) + }) + } +} + // TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a // valid NDP NA message with the Target Link Layer Address option results in a // new entry in the link address cache for the target of the message. @@ -197,6 +449,13 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { name: "Invalid Length", optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, }, + { + name: "Multiple", + optsBuf: []byte{ + 2, 1, 2, 3, 4, 5, 6, 7, + 2, 1, 2, 3, 4, 5, 6, 8, + }, + }, } for _, test := range tests { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 7c9fc48d1..193a9dfde 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -241,6 +241,16 @@ type NDPDispatcher interface { // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + // OnDNSSearchListOption will be called when an NDP option with a DNS + // search list has been received. + // + // It is up to the caller to use the domain names in the search list + // for only their valid lifetime. OnDNSSearchListOption may be called + // with new or already known domain names. If called with known domain + // names, their valid lifetimes must be refreshed to lifetime (it may + // be increased, decreased or completely invalidated when lifetime = 0. + OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) + // OnDHCPv6Configuration will be called with an updated configuration that is // available via DHCPv6 for a specified NIC. // @@ -305,6 +315,15 @@ type NDPConfigurations struct { // lifetime(s) of the generated address changes; this option only // affects the generation of new addresses as part of SLAAC. AutoGenGlobalAddresses bool + + // AutoGenAddressConflictRetries determines how many times to attempt to retry + // generation of a permanent auto-generated address in response to DAD + // conflicts. + // + // If the method used to generate the address does not support creating + // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's + // MAC address), then no attempt will be made to resolve the conflict. + AutoGenAddressConflictRetries uint8 } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -411,8 +430,23 @@ type slaacPrefixState struct { // Nonzero only when the address is not valid forever. validUntil time.Time + // Nonzero only when the address is not preferred forever. + preferredUntil time.Time + // The prefix's permanent address endpoint. + // + // May only be nil when a SLAAC address is being (re-)generated. Otherwise, + // must not be nil as all SLAAC prefixes must have a SLAAC address. ref *referencedNetworkEndpoint + + // The number of times a permanent address has been generated for the prefix. 
+ // + // Addresses may be regenerated in reseponse to a DAD conflicts. + generationAttempts uint8 + + // The maximum number of times to attempt regeneration of a permanent SLAAC + // address in response to DAD conflicts. + maxGenerationAttempts uint8 } // startDuplicateAddressDetection performs Duplicate Address Detection. @@ -687,7 +721,16 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { continue } - ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) + addrs, _ := opt.Addresses() + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) + + case header.NDPDNSSearchList: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + domainNames, _ := opt.DomainNames() + ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -935,60 +978,83 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { return } - // If the preferred lifetime is zero, then the prefix should be considered - // deprecated. - deprecated := pl == 0 - ref := ndp.addSLAACAddr(prefix, deprecated) - if ref == nil { - // We were unable to generate a permanent address for prefix so do nothing - // further as there is no reason to maintain state for a SLAAC prefix we - // cannot generate a permanent address for. - return - } - state := slaacPrefixState{ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - prefixState, ok := ndp.slaacPrefixes[prefix] + state, ok := ndp.slaacPrefixes[prefix] if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix)) + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(prefixState.ref) + ndp.deprecateSLAACAddress(state.ref) }), invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - ndp.invalidateSLAACPrefix(prefix, true) + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) + } + + ndp.invalidateSLAACPrefix(prefix, state) }), - ref: ref, + maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, + } + + now := time.Now() + + // The time an address is preferred until is needed to properly generate the + // address. + if pl < header.NDPInfiniteLifetime { + state.preferredUntil = now.Add(pl) + } + + if !ndp.generateSLAACAddr(prefix, &state) { + // We were unable to generate an address for the prefix, we do not nothing + // further as there is no reason to maintain state or timers for a prefix we + // do not have an address for. + return } // Setup the initial timers to deprecate and invalidate prefix. - if !deprecated && pl < header.NDPInfiniteLifetime { + if pl < header.NDPInfiniteLifetime && pl != 0 { state.deprecationTimer.Reset(pl) } if vl < header.NDPInfiniteLifetime { state.invalidationTimer.Reset(vl) - state.validUntil = time.Now().Add(vl) + state.validUntil = now.Add(vl) } ndp.slaacPrefixes[prefix] = state } -// addSLAACAddr adds a SLAAC address for prefix. +// generateSLAACAddr generates a SLAAC address for prefix. +// +// Returns true if an address was successfully generated. +// +// Panics if the prefix is not a SLAAC prefix or it already has an address. // // The NIC that ndp belongs to MUST be locked. 
-func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referencedNetworkEndpoint { +func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { + if r := state.ref; r != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if state.generationAttempts == state.maxGenerationAttempts { + return false + } + addrBytes := []byte(prefix.ID()) if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), - 0, /* dadCounter */ + state.generationAttempts, oIID.SecretKey, ) - } else { + } else if state.generationAttempts == 0 { // Only attempt to generate an interface-specific IID if we have a valid // link address. // @@ -996,12 +1062,16 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen // LinkEndpoint.LinkAddress) before reaching this point. linkAddr := ndp.nic.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { - return nil + return false } // Generate an address within prefix from the modified EUI-64 of ndp's NIC's // Ethernet MAC address. header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } else { + // We have no way to regenerate an address when addresses are not generated + // with opaque IIDs. + return false } generatedAddr := tcpip.ProtocolAddress{ @@ -1014,26 +1084,52 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen // If the nic already has this address, do nothing further. if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { - return nil + return false } // Inform the integrator that we have a new SLAAC address. ndpDisp := ndp.nic.stack.ndpDisp if ndpDisp == nil { - return nil + return false } if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { // Informed by the integrator not to add the address. - return nil + return false } + deprecated := time.Since(state.preferredUntil) >= 0 ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) if err != nil { panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err)) } - return ref + state.generationAttempts++ + state.ref = ref + return true +} + +// regenerateSLAACAddr regenerates an address for a SLAAC prefix. +// +// If generating a new address for the prefix fails, the prefix will be +// invalidated. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix)) + } + + if ndp.generateSLAACAddr(prefix, &state) { + ndp.slaacPrefixes[prefix] = state + return + } + + // We were unable to generate a permanent address for the SLAAC prefix so + // invalidate the prefix as there is no reason to maintain state for a + // SLAAC prefix we do not have an address for. + ndp.invalidateSLAACPrefix(prefix, state) } // refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. 
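generateSLAACAddr can only retry with a fresh identifier when opaque IIDs are configured, because the DAD counter (state.generationAttempts) is one of the inputs to the identifier. A minimal standalone sketch of that RFC 7217 idea; the hash and exact inputs below are illustrative and not a copy of header.AppendOpaqueInterfaceIdentifier:

package main

import (
	"crypto/sha256"
	"fmt"
	"net"
)

// opaqueIID sketches an RFC 7217-style interface identifier: a hash over the
// prefix, the NIC name, a DAD counter and a secret key. Bumping dadCounter on
// each DAD conflict yields a different IID for the same prefix, which is what
// makes address regeneration possible.
func opaqueIID(prefix net.IP, nicName string, dadCounter uint8, secret []byte) []byte {
	h := sha256.New()
	h.Write(prefix[:8]) // the /64 prefix octets
	h.Write([]byte(nicName))
	h.Write([]byte{dadCounter})
	h.Write(secret)
	return h.Sum(nil)[:8] // lower 64 bits form the IID
}

func main() {
	prefix := net.ParseIP("2001:db8::")
	secret := []byte("opaque-iid-secret")
	for counter := uint8(0); counter < 3; counter++ {
		fmt.Printf("attempt %d: IID %x\n", counter, opaqueIID(prefix, "eth0", counter, secret))
	}
}

Since only the counter changes between attempts, each retry produces a new address that is still stable for a given prefix, NIC and secret, whereas an EUI-64 derived address would collide again.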
@@ -1060,9 +1156,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim // deprecation timer so it can be reset. prefixState.deprecationTimer.StopLocked() + now := time.Now() + // Reset the deprecation timer if prefix has a finite preferred lifetime. - if !deprecated && pl < header.NDPInfiniteLifetime { - prefixState.deprecationTimer.Reset(pl) + if pl < header.NDPInfiniteLifetime { + if !deprecated { + prefixState.deprecationTimer.Reset(pl) + } + prefixState.preferredUntil = now.Add(pl) + } else { + prefixState.preferredUntil = time.Time{} } // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: @@ -1105,7 +1208,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim prefixState.invalidationTimer.StopLocked() prefixState.invalidationTimer.Reset(effectiveVl) - prefixState.validUntil = time.Now().Add(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) } // deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP @@ -1121,48 +1224,60 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { ref.deprecated = true if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }) + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) } } // invalidateSLAACPrefix invalidates a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) { - state, ok := ndp.slaacPrefixes[prefix] - if !ok { - return +func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { + if r := state.ref; r != nil { + // Since we are already invalidating the prefix, do not invalidate the + // prefix when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err)) + } } - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() - delete(ndp.slaacPrefixes, prefix) + ndp.cleanupSLAACPrefixResources(prefix, state) +} - addr := state.ref.ep.ID().LocalAddress +// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's +// resources. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + } - if removeAddr { - if err := ndp.nic.removePermanentAddressLocked(addr); err != nil { - panic(fmt.Sprintf("ndp: removePermanentAddressLocked(%s): %s", addr, err)) - } + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress { + return } - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: state.ref.ep.PrefixLen(), - }) + if !invalidatePrefix { + // If the prefix is not being invalidated, disassociate the address from the + // prefix and do nothing further. 
+ state.ref = nil + ndp.slaacPrefixes[prefix] = state + return } + + ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC -// address's resources from ndp. +// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// +// Panics if the SLAAC prefix is not known. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix) { - ndp.invalidateSLAACPrefix(addr.Subnet(), false) +func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + state.deprecationTimer.StopLocked() + state.invalidationTimer.StopLocked() + delete(ndp.slaacPrefixes, prefix) } // cleanupState cleans up ndp's state. @@ -1181,7 +1296,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() linkLocalPrefixes := 0 - for prefix := range ndp.slaacPrefixes { + for prefix, state := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. @@ -1190,7 +1305,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { continue } - ndp.invalidateSLAACPrefix(prefix, true) + ndp.invalidateSLAACPrefix(prefix, state) } if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 27dc8baf9..6dd460984 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -133,6 +133,12 @@ type ndpRDNSSEvent struct { rdnss ndpRDNSS } +type ndpDNSSLEvent struct { + nicID tcpip.NICID + domainNames []string + lifetime time.Duration +} + type ndpDHCPv6Event struct { nicID tcpip.NICID configuration stack.DHCPv6ConfigurationFromNDPRA @@ -150,6 +156,8 @@ type ndpDispatcher struct { rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent + dnsslC chan ndpDNSSLEvent + routeTable []tcpip.Route dhcpv6ConfigurationC chan ndpDHCPv6Event } @@ -257,6 +265,17 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } +// Implements stack.NDPDispatcher.OnDNSSearchListOption. +func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { + if n.dnsslC != nil { + n.dnsslC <- ndpDNSSLEvent{ + nicID, + domainNames, + lifetime, + } + } +} + // Implements stack.NDPDispatcher.OnDHCPv6Configuration. func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { if c := n.dhcpv6ConfigurationC; c != nil { @@ -623,6 +642,12 @@ func TestDADFail(t *testing.T) { if want := (tcpip.AddressWithPrefix{}); addr != want { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } + + // Attempting to add the address again should not fail if the address's + // state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } }) } } @@ -1959,7 +1984,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // addr2 is deprecated but if explicitly requested, it should be used. 
fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Another PI w/ 0 preferred lifetime should not result in a deprecation @@ -1972,7 +1997,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { } expectPrimaryAddr(addr1) if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Refresh lifetimes of addr generated from prefix2. @@ -2084,7 +2109,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make @@ -2097,7 +2122,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh lifetimes for addr of prefix1. @@ -2121,7 +2146,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since it is not deprecated. expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Wait for addr of prefix1 to be invalidated. @@ -2564,7 +2589,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { AddressWithPrefix: addr2, } if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -2783,6 +2808,461 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } } +// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an +// auto-generated IPv6 address in response to a DAD conflict. 
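The Errorf fixes in the hunks above all add a missing leading argument to format strings that name three verbs; go vet's printf check reports exactly this mismatch. A small reproduction with made-up values (the function name in the format string is only illustrative):

package main

import "fmt"

func main() {
	got, want := "fe80::1", "fe80::2"

	// Three verbs but two arguments: vet flags the call, and at runtime the
	// output ends with %!s(MISSING).
	fmt.Printf("got addrForNewConnection(_, _, %+v) = %s, want = %s\n", got, want)

	// Fixed form: the value formatted with %+v is passed explicitly.
	fullAddr := struct{ Addr string }{Addr: got}
	fmt.Printf("got addrForNewConnection(_, _, %+v) = %s, want = %s\n", fullAddr, got, want)
}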
+func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxMaxRetries = 3 + const lifetimeSeconds = 10 + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { + for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + subnet tcpip.Subnet + triggerSLAACFn func(e *channel.Endpoint) + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + subnet: subnet, + triggerSLAACFn: func(e *channel.Endpoint) { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + AutoGenAddressConflictRetries: maxRetries, + }, + autoGenLinkLocal: true, + subnet: header.IPv6LinkLocalPrefix.Subnet(), + triggerSLAACFn: func(e *channel.Endpoint) {}, + }, + } + + for _, addrType := range addrTypes { + maxRetries := maxRetries + numFailures := numFailures + addrType := addrType + + t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + addrType.triggerSLAACFn(e) + + // Simulate DAD conflicts so the address is regenerated. 
+ for i := uint8(0); i < numFailures; i++ { + addrBytes := []byte(addrType.subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Should not have any addresses assigned to the NIC. + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); mainAddr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + } + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Attempting to add the address manually should not fail if the + // address's state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + } + if err := s.RemoveAddress(nicID, addr.Address); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + } + + // Should not have any addresses assigned to the NIC. + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); mainAddr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + } + + // If we had less failures than generation attempts, we should have an + // address after DAD resolves. + if maxRetries+1 > numFailures { + addrBytes := []byte(addrType.subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if mainAddr != addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr) + } + } + + // Should not attempt address regeneration again. 
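+						// At this point either the final attempt passed DAD, or all
+						// maxRetries+1 attempts were consumed; generateSLAACAddr refuses
+						// to run once generationAttempts reaches maxGenerationAttempts,
+						// so no further auto-generated address events are expected
+						// within the async timeout.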
+ select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } + } + } +} + +// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is +// not made for SLAAC addresses generated with an IID based on the NIC's link +// address. +func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { + const nicID = 1 + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxRetries = 3 + const lifetimeSeconds = 10 + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + subnet tcpip.Subnet + triggerSLAACFn func(e *channel.Endpoint) + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + subnet: subnet, + triggerSLAACFn: func(e *channel.Endpoint) { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + AutoGenAddressConflictRetries: maxRetries, + }, + autoGenLinkLocal: true, + subnet: header.IPv6LinkLocalPrefix.Subnet(), + triggerSLAACFn: func(e *channel.Endpoint) {}, + }, + } + + for _, addrType := range addrTypes { + addrType := addrType + + t.Run(addrType.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + addrType.triggerSLAACFn(e) + + addrBytes := []byte(addrType.subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Should not attempt address regeneration. 
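+			// An EUI-64 based IID is derived deterministically from the NIC's MAC
+			// address, so a retry would produce the same conflicting address;
+			// generateSLAACAddr therefore only regenerates when opaque IIDs are in
+			// use, regardless of AutoGenAddressConflictRetries.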
+ select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } +} + +// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address +// generation in response to DAD conflicts does not refresh the lifetimes. +func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = 2 * time.Second + const failureTimer = time.Second + const maxRetries = 1 + const lifetimeSeconds = 5 + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + addrBytes := []byte(subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict after some time has passed. + time.Sleep(failureTimer) + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Let the next address resolve. 
+ addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) + expectAutoGenAddrEvent(addr, newAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // Address should be deprecated/invalidated after the lifetime expires. + // + // Note, the remaining lifetime is calculated from when the PI was first + // processed. Since we wait for some time before simulating a DAD conflict + // and more time for the new address to resolve, the new address is only + // expected to be valid for the remaining time. The DAD conflict should + // not have reset the lifetimes. + // + // We expect either just the invalidation event or the deprecation event + // followed by the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if e.eventType == deprecatedAddr { + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") + } + } else { + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + } + case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for auto gen addr event") + } +} + // TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event // to the integrator when an RA is received with the NDP Recursive DNS Server // option with at least one valid address. @@ -2925,6 +3405,112 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } } +// TestNDPDNSSearchListDispatch tests that the integrator is informed when an +// NDP DNS Search List option is received with at least one domain name in the +// search list. 
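The final wait in TestAutoGenAddrContinuesLifetimesAfterRetry encodes the claim that prefix lifetimes keep running across a retry: the regenerated address only has the RA lifetime minus the pre-conflict delay minus its own DAD round left. A quick check of that arithmetic, using the same constants as the test above:

package main

import (
	"fmt"
	"time"
)

func main() {
	const (
		lifetime        = 5 * time.Second // prefix valid/preferred lifetime from the RA
		failureTimer    = 1 * time.Second // delay before the simulated DAD conflict
		dadTransmits    = 1
		retransmitTimer = 2 * time.Second // wait per DAD probe
	)

	// Lifetimes are measured from when the PI was processed, so the address
	// generated after the conflict inherits whatever time remains.
	remaining := lifetime - failureTimer - dadTransmits*retransmitTimer
	fmt.Println(remaining) // prints 2s
}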
+func TestNDPDNSSearchListDispatch(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dnsslC: make(chan ndpDNSSLEvent, 3), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + optSer := header.NDPOptionsSerializer{ + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 0, + 2, 'h', 'i', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 'i', + 0, + 2, 'a', 'm', + 2, 'm', 'e', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 1, 0, + 3, 'x', 'y', 'z', + 0, + 5, 'h', 'e', 'l', 'l', 'o', + 5, 'w', 'o', 'r', 'l', 'd', + 0, + 4, 't', 'h', 'i', 's', + 2, 'i', 's', + 1, 'a', + 4, 't', 'e', 's', 't', + 0, + }), + } + expected := []struct { + domainNames []string + lifetime time.Duration + }{ + { + domainNames: []string{ + "hi", + }, + lifetime: 0, + }, + { + domainNames: []string{ + "i", + "am.me", + }, + lifetime: time.Second, + }, + { + domainNames: []string{ + "xyz", + "hello.world", + "this.is.a.test", + }, + lifetime: 256 * time.Second, + }, + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + + for i, expected := range expected { + select { + case dnssl := <-ndpDisp.dnsslC: + if dnssl.nicID != nicID { + t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID) + } + if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" { + t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff) + } + if dnssl.lifetime != expected.lifetime { + t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime) + } + default: + t.Fatal("expected a DNSSL event") + } + } + + // Should have no more DNSSL options. + select { + case <-ndpDisp.dnsslC: + t.Fatal("unexpectedly got a DNSSL event") + default: + } +} + // TestCleanupNDPState tests that all discovered routers and prefixes, and // auto-generated addresses are invalidated when a NIC becomes a router. 
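The raw option bodies injected in the test above follow the RFC 8106 DNSSL layout: after the option type and length come two reserved bytes, a 32-bit lifetime in seconds (so {0, 0, 1, 0} is the 256-second case), then domain names in DNS wire format, one length-prefixed label per component with a zero label ending each name. A standalone sketch of decoding such a body, using the third test vector; it is illustrative rather than the header.NDPDNSSearchList implementation:

package main

import (
	"encoding/binary"
	"fmt"
	"strings"
	"time"
)

// parseDNSSLBody decodes the bytes after the option's Type and Length octets:
// 2 reserved bytes, a 4-byte lifetime in seconds, then domain names encoded as
// length-prefixed labels terminated by a zero label.
func parseDNSSLBody(b []byte) (time.Duration, []string) {
	lifetime := time.Duration(binary.BigEndian.Uint32(b[2:6])) * time.Second
	var names, labels []string
	for i := 6; i < len(b); {
		l := int(b[i])
		i++
		if l == 0 {
			if len(labels) > 0 {
				names = append(names, strings.Join(labels, "."))
				labels = labels[:0]
			}
			continue
		}
		labels = append(labels, string(b[i:i+l]))
		i += l
	}
	return lifetime, names
}

func main() {
	body := []byte{
		0, 0, // reserved
		0, 0, 1, 0, // lifetime = 256 seconds
		3, 'x', 'y', 'z', 0,
		5, 'h', 'e', 'l', 'l', 'o', 5, 'w', 'o', 'r', 'l', 'd', 0,
	}
	lifetime, names := parseDNSSLBody(body)
	fmt.Println(lifetime, names) // 4m16s [xyz hello.world]
}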
func TestCleanupNDPState(t *testing.T) { @@ -3483,7 +4069,8 @@ func TestRouterSolicitation(t *testing.T) { e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() p, ok := e.ReadContext(ctx) if !ok { t.Fatal("timed out waiting for packet") @@ -3513,7 +4100,8 @@ func TestRouterSolicitation(t *testing.T) { } waitForNothing := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4835251bc..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1012,29 +1012,31 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } - isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) + switch r.protocol { + case header.IPv6ProtocolNumber: + return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */) + default: + r.expireLocked() + return nil + } +} + +func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error { + addr := r.addrWithPrefix() + + isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) if isIPv6Unicast { - // If we are removing a tentative IPv6 unicast address, stop DAD. - if kind == permanentTentative { - n.mu.ndp.stopDuplicateAddressDetection(addr) - } + n.mu.ndp.stopDuplicateAddressDetection(addr.Address) // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: r.ep.PrefixLen(), - }) + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation) } } - r.setKind(permanentExpired) - if !r.decRefLocked() { - // The endpoint still has references to it. - return nil - } + r.expireLocked() // At this point the endpoint is deleted. @@ -1044,7 +1046,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node // multicast group may be left by user action. if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr) + snmc := header.SolicitedNodeAddr(addr.Address) if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1425,10 +1427,12 @@ func (n *NIC) isAddrTentative(addr tcpip.Address) bool { return ref.getKind() == permanentTentative } -// dupTentativeAddrDetected attempts to inform n that a tentative addr -// is a duplicate on a link. +// dupTentativeAddrDetected attempts to inform n that a tentative addr is a +// duplicate on a link. // -// dupTentativeAddrDetected will delete the tentative address if it exists. +// dupTentativeAddrDetected will remove the tentative address if it exists. If +// the address was generated via SLAAC, an attempt will be made to generate a +// new address. 
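The TestRouterSolicitation fixes above stop discarding the CancelFunc returned by context.WithTimeout: cancel must run (usually via defer) so the underlying timer is released even when the read returns before the deadline, and vet's lostcancel check flags the discarded form. A minimal sketch of the pattern:

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	// Discarding the cancel func, as in
	//   ctx, _ := context.WithTimeout(context.Background(), time.Second)
	// leaks the timer until the deadline fires and is reported by vet.
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel() // releases resources as soon as we return

	select {
	case <-ctx.Done():
		fmt.Println("deadline:", ctx.Err())
	case <-time.After(time.Second):
		fmt.Println("unexpected")
	}
}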
func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1442,7 +1446,17 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - return n.removePermanentAddressLocked(addr) + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a + // new address will be generated for it. + if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil { + return err + } + + if ref.configType == slaac { + n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet()) + } + + return nil } // setNDPConfigs sets the NDP configurations for n. @@ -1570,6 +1584,13 @@ type referencedNetworkEndpoint struct { deprecated bool } +func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { + return tcpip.AddressWithPrefix{ + Address: r.ep.ID().LocalAddress, + PrefixLen: r.ep.PrefixLen(), + } +} + func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) } @@ -1597,6 +1618,13 @@ func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) } +// expireLocked decrements the reference count and marks the permanent endpoint +// as expired. +func (r *referencedNetworkEndpoint) expireLocked() { + r.setKind(permanentExpired) + r.decRefLocked() +} + // decRef decrements the ref count and cleans up the endpoint once it reaches // zero. func (r *referencedNetworkEndpoint) decRef() { @@ -1606,14 +1634,11 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. Returns true if the endpoint was removed. -func (r *referencedNetworkEndpoint) decRefLocked() bool { +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) - return true } - - return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3f8a2a095..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1445,19 +1445,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } // If the NIC doesn't exist, it won't work. 
if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } } @@ -1483,12 +1483,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { } nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } // Set the initial route table. @@ -1503,10 +1503,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // When an interface is given, the route for a broadcast goes through it. r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } // When an interface is not given, it consults the route table. @@ -2399,7 +2399,7 @@ func TestNICContextPreservation(t *testing.T) { t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) } if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) } }) } @@ -2768,7 +2768,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { - t.Fatalf("NewSubnet failed:", err) + t.Fatalf("NewSubnet failed: %v", err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } @@ -2782,11 +2782,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // permanentExpired kind. r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) if err != nil { - t.Fatal("FindRoute failed:", err) + t.Fatalf("FindRoute failed: %v", err) } defer r.Release() if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } // @@ -2798,7 +2798,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. 
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } @@ -2806,7 +2806,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // make sure the new peb was respected. // (The address should just be promoted now). if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[1].ProtocolAddresses { @@ -2839,11 +2839,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // GetMainNICAddress; else, our original address // should be returned. if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } addr, err = s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + t.Fatalf("s.GetMainNICAddress failed: %v", err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index c65b0c632..2474a7db3 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -206,7 +206,7 @@ func TestTransportDemuxerRegister(t *testing.T) { // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { - reuse int + reuse bool bindToDevice tcpip.NICID } for _, test := range []struct { @@ -221,11 +221,11 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -236,9 +236,9 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {reuse: 0, bindToDevice: 1}, - {reuse: 0, bindToDevice: 2}, - {reuse: 0, bindToDevice: 3}, + {reuse: false, bindToDevice: 1}, + {reuse: false, bindToDevice: 2}, + {reuse: false, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -253,12 +253,12 @@ func TestBindToDeviceDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. 
[]endpointSockopts{ - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -309,9 +309,8 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err) + if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { + t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2ef3271f1..1ca4088c9 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -520,34 +520,90 @@ type WriteOptions struct { type SockOptBool int const ( + // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether + // datagram sockets are allowed to send packets to a broadcast address. + BroadcastOption SockOptBool = iota + + // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be + // held until segments are full by the TCP transport protocol. + CorkOption + + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + + // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether + // TCP keepalive is enabled for this socket. + KeepaliveEnabledOption + + // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether + // multicast packets sent over a non-loopback interface will be looped back. + MulticastLoopOption + + // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether + // SCM_CREDENTIALS socket control messages are enabled. + // + // Only supported on Unix sockets. + PasscredOption + + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. + QuickAckOption + // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the // IPV6_TCLASS ancillary message is passed with incoming packets. - ReceiveTClassOption SockOptBool = iota + ReceiveTClassOption // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS // ancillary message is passed with incoming packets. ReceiveTOSOption - // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 - // socket is to be restricted to sending and receiving IPv6 packets only. - V6OnlyOption - // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify // if more inforamtion is provided with incoming packets such // as interface index and address. ReceiveIPPacketInfoOption - // TODO(b/146901447): convert existing bool socket options to be handled via - // Get/SetSockOptBool + // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() + // should allow reuse of local address. 
+ ReuseAddressOption + + // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets + // to be bound to an identical socket address. + ReusePortOption + + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption ) // SockOptInt represents socket options which values have the int type. type SockOptInt int const ( + // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number + // of un-ACKed TCP keepalives that will be sent before the connection is + // closed. + KeepaliveCountOption SockOptInt = iota + + // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv4 packets from the endpoint. + IPv4TOSOption + + // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv6 packets from the endpoint. + IPv6TrafficClassOption + + // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current + // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. + MaxSegOption + + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default + // TTL value for multicast messages. The default is 1. + MulticastTTLOption + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOptInt = iota + ReceiveQueueSizeOption // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. @@ -561,44 +617,21 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. For TCP, - // it determines if the Nagle algorithm is on or off. - DelayOption - - // TODO(b/137664753): convert all int socket options to be handled via - // GetSockOptInt. + // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop + // limit value for unicast messages. The default is protocol specific. + // + // A zero value indicates the default. + TTLOption ) // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} -// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be -// held until segments are full by the TCP transport protocol. -type CorkOption int - -// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() -// should allow reuse of local address. -type ReuseAddressOption int - -// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets -// to be bound to an identical socket address. -type ReusePortOption int - // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. type BindToDeviceOption NICID -// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. -type QuickAckOption int - -// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether -// SCM_CREDENTIALS socket control messages are enabled. -// -// Only supported on Unix sockets. -type PasscredOption int - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. 
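The options folded into SockOptInt above are set with SetSockOptInt instead of the old per-type SetSockOpt values; a minimal sketch of a caller, assuming only the tcpip.Endpoint methods shown in this diff (helper name and parameters are illustrative):

package example

import "gvisor.dev/gvisor/pkg/tcpip"

// tuneOutgoing applies three of the int-valued options moved into SockOptInt:
// unicast TTL, IPv4 TOS and the TCP maximum segment size.
func tuneOutgoing(ep tcpip.Endpoint, ttl, tos, mss int) *tcpip.Error {
	if err := ep.SetSockOptInt(tcpip.TTLOption, ttl); err != nil {
		return err
	}
	if err := ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
		return err
	}
	// The TCP endpoint validates MaxSegOption against TCPMinimumMSS and
	// TCPMaximumMSS and returns ErrInvalidOptionValue when out of range.
	return ep.SetSockOptInt(tcpip.MaxSegOption, mss)
}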
@@ -607,10 +640,6 @@ type TCPInfoOption struct { RTTVar time.Duration } -// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP keepalive is enabled for this socket. -type KeepaliveEnabledOption int - // KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a // connection must remain idle before the first TCP keepalive packet is sent. // Once this time is reached, KeepaliveIntervalOption is used instead. @@ -620,11 +649,6 @@ type KeepaliveIdleOption time.Duration // interval between sending TCP keepalive packets. type KeepaliveIntervalOption time.Duration -// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number -// of un-ACKed TCP keepalives that will be sent before the connection is -// closed. -type KeepaliveCountOption int - // TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user // specified timeout for a given TCP connection. // See: RFC5482 for details. @@ -638,20 +662,9 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive // buffer moderation. type ModerateReceiveBufferOption bool -// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current -// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. -type MaxSegOption int - -// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop -// limit value for unicast messages. The default is protocol specific. -// -// A zero value indicates the default. -type TTLOption uint8 - // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. @@ -668,9 +681,14 @@ type TCPTimeWaitTimeoutOption time.Duration // for a handshake till the specified timeout until a segment with data arrives. type TCPDeferAcceptOption time.Duration -// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default -// TTL value for multicast messages. The default is 1. -type MulticastTTLOption uint8 +// TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding +// default MinRTO used by the Stack. +type TCPMinRTOOption time.Duration + +// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify +// the number of endpoints that can be in SYN-RCVD state before the stack +// switches to using SYN cookies. +type TCPSynRcvdCountThresholdOption uint64 // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. @@ -679,10 +697,6 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } -// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether -// multicast packets sent over a non-loopback interface will be looped back. -type MulticastLoopOption bool - // MembershipOption is used by SetSockOpt/GetSockOpt as an argument to // AddMembershipOption and RemoveMembershipOption. type MembershipOption struct { @@ -705,22 +719,10 @@ type RemoveMembershipOption MembershipOption // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int -// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether -// datagram sockets are allowed to send packets to a broadcast address. -type BroadcastOption int - // DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify // a default TTL. 
type DefaultTTLOption uint8 -// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv4 packets from the endpoint. -type IPv4TOSOption uint8 - -// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv6 packets from the endpoint. -type IPv6TrafficClassOption uint8 - // IPPacketInfo is the message struture for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 8c0aacffa..1c8e2bc34 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -218,7 +218,7 @@ func TestAddressWithPrefixSubnet(t *testing.T) { gotSubnet := ap.Subnet() wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) if err != nil { - t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) continue } if gotSubnet != wantSubnet { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b007302fb..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -348,13 +348,6 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(o) - e.mu.Unlock() - } - return nil } @@ -365,12 +358,25 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt sets a socket option. Currently not supported. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + } return nil } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -397,26 +403,23 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + case tcpip.TTLOption: + e.rcvMu.Lock() + v := int(e.ttl) + e.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.TTLOption: - e.rcvMu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.rcvMu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 337bc1c71..eee754a5a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -533,14 +533,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
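Reads follow the same split: the ICMP and raw endpoints above now answer GetSockOptInt(TTLOption) and report keepalive as disabled through GetSockOptBool. A small sketch of reading both back, with a hypothetical helper name:

package example

import "gvisor.dev/gvisor/pkg/tcpip"

// readTTLAndKeepalive queries the enum-based getters added in this change.
func readTTLAndKeepalive(ep tcpip.Endpoint) (ttl int, keepalive bool, err *tcpip.Error) {
	if ttl, err = ep.GetSockOptInt(tcpip.TTLOption); err != nil {
		return 0, false, err
	}
	if keepalive, err = ep.GetSockOptBool(tcpip.KeepaliveEnabledOption); err != nil {
		return 0, false, err
	}
	return ttl, keepalive, nil
}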
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -548,7 +544,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -576,9 +578,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + default: + return -1, tcpip.ErrUnknownProtocolOption } - - return -1, tcpip.ErrUnknownProtocolOption } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7f94f9646..edb7718a6 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -87,7 +87,9 @@ go_test( "tcp_timestamp_test.go", ], # FIXME(b/68809571) - tags = ["flaky"], + tags = [ + "flaky", + ], deps = [ ":tcp", "//pkg/sync", @@ -104,5 +106,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", "//pkg/waiter", + "//runsc/testutil", ], ) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7a9dea4ac..5bb243e3b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -17,6 +17,7 @@ package tcp import ( "crypto/sha1" "encoding/binary" + "fmt" "hash" "io" "time" @@ -25,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -49,17 +49,14 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. maxTSDiff = 2 -) -var ( - // SynRcvdCountThreshold is the global maximum number of connections - // that are allowed to be in SYN-RCVD state before TCP starts using SYN - // cookies to accept connections. - // - // It is an exported variable only for testing, and should not otherwise - // be used by importers of this package. + // SynRcvdCountThreshold is the default global maximum number of + // connections that are allowed to be in SYN-RCVD state before TCP + // starts using SYN cookies to accept connections. SynRcvdCountThreshold uint64 = 1000 +) +var ( // mssTable is a slice containing the possible MSS values that we // encode in the SYN cookie with two bits. mssTable = []uint16{536, 1300, 1440, 1460} @@ -74,29 +71,42 @@ func encodeMSS(mss uint16) uint32 { return 0 } -// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is -// protected by a mutex so that we can increment only when it's guaranteed not -// to go above a threshold. -var synRcvdCount struct { - sync.Mutex - value uint64 - pending sync.WaitGroup -} - // listenContext is used by a listening endpoint to store state used while // listening for connections. This struct is allocated by the listen goroutine // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. 
type listenContext struct { - stack *stack.Stack - rcvWnd seqnum.Size - nonce [2][sha1.BlockSize]byte + stack *stack.Stack + + // synRcvdCount is a reference to the stack level synRcvdCount. + synRcvdCount *synRcvdCounter + + // rcvWnd is the receive window that is sent by this listening context + // in the initial SYN-ACK. + rcvWnd seqnum.Size + + // nonce are random bytes that are initialized once when the context + // is created and used to seed the hash function when generating + // the SYN cookie. + nonce [2][sha1.BlockSize]byte + + // listenEP is a reference to the listening endpoint associated with + // this context. Can be nil if the context is created by the forwarder. listenEP *endpoint + // hasherMu protects hasher. hasherMu sync.Mutex - hasher hash.Hash - v6only bool + // hasher is the hash function used to generate a SYN cookie. + hasher hash.Hash + + // v6Only is true if listenEP is a dual stack socket and has the + // IPV6_V6ONLY option set. + v6only bool + + // netProto indicates the network protocol(IPv4/v6) for the listening + // endpoint. netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed // by the listening endpoint's worker goroutine. // @@ -115,44 +125,6 @@ func timeStamp() uint32 { return uint32(time.Now().Unix()>>6) & tsMask } -// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD -// state. It succeeds if the increment doesn't make the count go beyond the -// threshold, and fails otherwise. -func incSynRcvdCount() bool { - synRcvdCount.Lock() - - if synRcvdCount.value >= SynRcvdCountThreshold { - synRcvdCount.Unlock() - return false - } - - synRcvdCount.pending.Add(1) - synRcvdCount.value++ - - synRcvdCount.Unlock() - return true -} - -// decSynRcvdCount atomically decrements the global number of endpoints in -// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount -// succeeded. -func decSynRcvdCount() { - synRcvdCount.Lock() - - synRcvdCount.value-- - synRcvdCount.pending.Done() - synRcvdCount.Unlock() -} - -// synCookiesInUse() returns true if the synRcvdCount is greater than -// SynRcvdCountThreshold. -func synCookiesInUse() bool { - synRcvdCount.Lock() - v := synRcvdCount.value - synRcvdCount.Unlock() - return v >= SynRcvdCountThreshold -} - // newListenContext creates a new listen context. 
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ @@ -164,6 +136,11 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } + p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) + if !ok { + panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) + } + l.synRcvdCount = p.SynRcvdCounter() rand.Read(l.nonce[0][:]) rand.Read(l.nonce[1][:]) @@ -330,6 +307,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if l.listenEP != nil { l.removePendingEndpoint(ep) } + + ep.drainClosingSegmentQueue() + return nil, err } ep.isConnectNotified = true @@ -378,7 +358,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { for { if e.acceptedChan == nil { e.acceptMu.Unlock() - n.Close() + n.notifyProtocolGoroutine(notifyReset) return } select { @@ -407,7 +387,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer decSynRcvdCount() + defer ctx.synRcvdCount.dec() defer func() { e.mu.Lock() e.decSynRcvdCount() @@ -452,19 +432,16 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() + if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + // If the endpoint is shutdown, reply with reset. + // // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - e.sendTCP(&s.route, tcpFields{ - id: s.id, - ttl: e.ttl, - tos: e.sendTOS, - flags: header.TCPFlagRst, - seq: s.ackNumber, - ack: 0, - rcvWnd: 0, - }, buffer.VectorisedView{}, nil) + replyWithReset(s, e.sendTOS, e.ttl) return } @@ -474,7 +451,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { switch { case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) - if incSynRcvdCount() { + if ctx.synRcvdCount.inc() { // Only handle the syn if the following conditions hold // - accept queue is not full. // - number of connections in synRcvd state is less than the @@ -484,7 +461,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. return } - decSynRcvdCount() + ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -537,7 +514,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } - if !synCookiesInUse() { + if !ctx.synRcvdCount.synCookiesInUse() { // When not using SYN cookies, as per RFC 793, section 3.9, page 64: // Any acknowledgment is bad if it arrives on a connection still in // the LISTEN state. 
An acceptable reset segment should be formed @@ -553,7 +530,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s) + replyWithReset(s, e.sendTOS, e.ttl) return } @@ -656,6 +633,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { } e.mu.Unlock() + e.drainClosingSegmentQueue() + // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 2ca3fb809..368865911 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1053,15 +1053,34 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) } if ep == nil { - replyWithReset(s) + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) s.decRef() return } + + if e == ep { + panic("current endpoint not removed from demuxer, enqueing segments to itself") + } + if ep.(*endpoint).enqueueSegment(s) { ep.(*endpoint).newSegmentWaker.Assert() } } +// Drain segment queue from the endpoint and try to re-match the segment to a +// different endpoint. This is used when the current endpoint is transitioned to +// StateClose and has been unregistered from the transport demuxer. +func (e *endpoint) drainClosingSegmentQueue() { + for { + s := e.segmentQueue.dequeue() + if s == nil { + break + } + + e.tryDeliverSegmentFromClosedEndpoint(s) + } +} + func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states @@ -1315,6 +1334,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.mu.Unlock() + + e.drainClosingSegmentQueue() + // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1565,19 +1587,6 @@ loop: // Lock released below. epilogue() - // epilogue removes the endpoint from the transport-demuxer and - // unlocks e.mu. Now that no new segments can get enqueued to this - // endpoint, try to re-match the segment to a different endpoint - // as the current endpoint is closed. - for { - s := e.segmentQueue.dequeue() - if s == nil { - break - } - - e.tryDeliverSegmentFromClosedEndpoint(s) - } - // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue if reuseTW != nil { diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 4f361b226..804e95aea 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -568,11 +568,10 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. 
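The package-level tcp.SynRcvdCountThreshold variable used on the removed lines below is replaced by a per-stack protocol option. Outside the test context helpers, configuring it could look roughly like this, assuming the stack.New/Options constructors used elsewhere in this tree:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)

// newSynCookieStack builds a stack that always uses SYN cookies by setting the
// SYN-RCVD threshold to zero, the same option the updated tests set through
// their context helper.
func newSynCookieStack() (*stack.Stack, *tcpip.Error) {
	s := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
		TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
	})
	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
		return nil, err
	}
	return s, nil
}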
- saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + } + const n = uint16(32) // Start listening. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9b123e968..7e8def82d 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -821,7 +821,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { - e.SetSockOptInt(tcpip.DelayOption, 1) + e.SetSockOptBool(tcpip.DelayOption, true) } var tcpLT tcpip.TCPLingerTimeoutOption @@ -980,25 +980,22 @@ func (e *endpoint) closeNoShutdownLocked() { // Mark endpoint as closed. e.closed = true + + switch e.EndpointState() { + case StateClose, StateError: + return + } + // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. - switch e.EndpointState() { - // Sockets in StateSynRecv state(passive connections) are closed when - // the handshake fails or if the listening socket is closed while - // handshake was in progress. In such cases the handshake goroutine - // is already gone by the time Close is called and we need to cleanup - // here. - case StateInitial, StateBound, StateSynRecv: - e.cleanupLocked() - e.setEndpointState(StateClose) - case StateError, StateClose: - // do nothing. - default: + if e.workerRunning { e.workerCleanup = true tcpip.AddDanglingEndpoint(e) // Worker will remove the dangling endpoint when the endpoint // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) + } else { + e.transitionToStateCloseLocked() } } @@ -1010,13 +1007,18 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Unlock() return } - close(e.acceptedChan) + ch := e.acceptedChan e.acceptedChan = nil e.acceptCond.Broadcast() e.acceptMu.Unlock() - // Wait for all pending endpoints to close. + // Reset all connections that are waiting to be accepted. + for n := range ch { + n.notifyProtocolGoroutine(notifyReset) + } + // Wait for reset of all endpoints that are still waiting to be delivered to + // the now closed acceptedChan. e.pendingAccepted.Wait() } @@ -1409,10 +1411,58 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - e.LockUser() - defer e.UnlockUser() - switch opt { + + case tcpip.BroadcastOption: + e.LockUser() + e.broadcast = v + e.UnlockUser() + + case tcpip.CorkOption: + e.LockUser() + if !v { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + e.UnlockUser() + + case tcpip.DelayOption: + if v { + atomic.StoreUint32(&e.delay, 1) + } else { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. 
+ e.sndWaker.Assert() + } + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.QuickAckOption: + o := uint32(1) + if v { + o = 0 + } + atomic.StoreUint32(&e.slowAck, o) + + case tcpip.ReuseAddressOption: + e.LockUser() + e.reuseAddr = v + e.UnlockUser() + + case tcpip.ReusePortOption: + e.LockUser() + e.reusePort = v + e.UnlockUser() + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1424,7 +1474,9 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + e.LockUser() e.v6only = v + e.UnlockUser() } return nil @@ -1432,18 +1484,50 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.IPv4TOSOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.userMSS = uint16(userMSS) + e.UnlockUser() + e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. var rs ReceiveBufferSizeOption - size := int(v) if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if size < rs.Min { - size = rs.Min + if v < rs.Min { + v = rs.Min } - if size > rs.Max { - size = rs.Max + if v > rs.Max { + v = rs.Max } } @@ -1458,17 +1542,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { if e.rcv != nil { scale = e.rcv.rcvWndScale } - if size>>scale == 0 { - size = 1 << scale + if v>>scale == 0 { + v = 1 << scale } // Make sure 2*size doesn't overflow. - if size > math.MaxInt32/2 { - size = math.MaxInt32 / 2 + if v > math.MaxInt32/2 { + v = math.MaxInt32 / 2 } availBefore := e.receiveBufferAvailableLocked() - e.rcvBufSize = size + e.rcvBufSize = v availAfter := e.receiveBufferAvailableLocked() e.rcvAutoParams.disabled = true @@ -1483,71 +1567,36 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() e.notifyProtocolGoroutine(mask) - return nil case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. 
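Keepalive configuration now spans all three setter families: the enable flag is a bool, the probe count is an int, and the idle/interval durations stay typed SetSockOpt values. A hedged sketch combining them, based only on the option names in this diff (the helper is hypothetical):

package example

import (
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// enableKeepalive turns on TCP keepalives with the given probing parameters.
func enableKeepalive(ep tcpip.Endpoint, idle, interval time.Duration, count int) *tcpip.Error {
	if err := ep.SetSockOpt(tcpip.KeepaliveIdleOption(idle)); err != nil {
		return err
	}
	if err := ep.SetSockOpt(tcpip.KeepaliveIntervalOption(interval)); err != nil {
		return err
	}
	if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, count); err != nil {
		return err
	}
	return ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
}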
- size := int(v) var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if size < ss.Min { - size = ss.Min + if v < ss.Min { + v = ss.Min } - if size > ss.Max { - size = ss.Max + if v > ss.Max { + v = ss.Max } } e.sndBufMu.Lock() - e.sndBufSize = size + e.sndBufSize = v e.sndBufMu.Unlock() - return nil - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil + case tcpip.TTLOption: + e.LockUser() + e.ttl = uint8(v) + e.UnlockUser() - default: - return nil } + return nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 switch v := opt.(type) { - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.LockUser() - e.reuseAddr = v != 0 - e.UnlockUser() - return nil - - case tcpip.ReusePortOption: - e.LockUser() - e.reusePort = v != 0 - e.UnlockUser() - return nil - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -1556,72 +1605,26 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.LockUser() e.bindToDevice = id e.UnlockUser() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.LockUser() - e.userMSS = uint16(userMSS) - e.UnlockUser() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - - case tcpip.TTLOption: - e.LockUser() - e.ttl = uint8(v) - e.UnlockUser() - return nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v != 0 - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIntervalOption: e.keepalive.Lock() e.keepalive.interval = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil - case tcpip.KeepaliveCountOption: - e.keepalive.Lock() - e.keepalive.count = int(v) - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + case tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. case tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(v) e.UnlockUser() - return nil - - case tcpip.BroadcastOption: - e.LockUser() - e.broadcast = v != 0 - e.UnlockUser() - return nil case tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and @@ -1652,22 +1655,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.IPv4TOSOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. 
- e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - - case tcpip.IPv6TrafficClassOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - case tcpip.TCPLingerTimeoutOption: e.LockUser() if v < 0 { @@ -1688,7 +1675,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.tcpLingerTimeout = time.Duration(v) e.UnlockUser() - return nil case tcpip.TCPDeferAcceptOption: e.LockUser() @@ -1697,11 +1683,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.deferAccept = time.Duration(v) e.UnlockUser() - return nil default: return nil } + return nil } // readyReceiveSize returns the number of bytes ready to be received. @@ -1723,6 +1709,43 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.LockUser() + v := e.broadcast + e.UnlockUser() + return v, nil + + case tcpip.CorkOption: + return atomic.LoadUint32(&e.cork) != 0, nil + + case tcpip.DelayOption: + return atomic.LoadUint32(&e.delay) != 0, nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + return v, nil + + case tcpip.QuickAckOption: + v := atomic.LoadUint32(&e.slowAck) == 0 + return v, nil + + case tcpip.ReuseAddressOption: + e.LockUser() + v := e.reuseAddr + e.UnlockUser() + + return v, nil + + case tcpip.ReusePortOption: + e.LockUser() + v := e.reusePort + e.UnlockUser() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1734,14 +1757,41 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.UnlockUser() return v, nil - } - return false, tcpip.ErrUnknownProtocolOption + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + v := e.keepalive.count + e.keepalive.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + v := header.TCPDefaultMSS + return v, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1757,12 +1807,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil - case tcpip.DelayOption: - var o int - if v := atomic.LoadUint32(&e.delay); v != 0 { - o = 1 - } - return o, nil + case tcpip.TTLOption: + e.LockUser() + v := int(e.ttl) + e.UnlockUser() + return v, nil default: return -1, tcpip.ErrUnknownProtocolOption @@ -1779,61 +1828,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err - case *tcpip.MaxSegOption: - // This is just stubbed out. 
Linux never returns the user_mss - // value as it either returns the defaultMSS or returns the - // actual current MSS. Netstack just returns the defaultMSS - // always for now. - *o = header.TCPDefaultMSS - return nil - - case *tcpip.CorkOption: - *o = 0 - if v := atomic.LoadUint32(&e.cork); v != 0 { - *o = 1 - } - return nil - - case *tcpip.ReuseAddressOption: - e.LockUser() - v := e.reuseAddr - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.ReusePortOption: - e.LockUser() - v := e.reusePort - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.BindToDeviceOption: e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.UnlockUser() - return nil - - case *tcpip.QuickAckOption: - *o = 1 - if v := atomic.LoadUint32(&e.slowAck); v != 0 { - *o = 0 - } - return nil - - case *tcpip.TTLOption: - e.LockUser() - *o = tcpip.TTLOption(e.ttl) - e.UnlockUser() - return nil case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} @@ -1846,92 +1844,45 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { o.RTTVar = snd.rtt.rttvar snd.rtt.Unlock() } - return nil - - case *tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) e.keepalive.Unlock() - return nil case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) e.keepalive.Unlock() - return nil - - case *tcpip.KeepaliveCountOption: - e.keepalive.Lock() - *o = tcpip.KeepaliveCountOption(e.keepalive.count) - e.keepalive.Unlock() - return nil case *tcpip.TCPUserTimeoutOption: e.LockUser() *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - return nil case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 - return nil - - case *tcpip.BroadcastOption: - e.LockUser() - v := e.broadcast - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc e.UnlockUser() - return nil - - case *tcpip.IPv4TOSOption: - e.LockUser() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.UnlockUser() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.LockUser() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.UnlockUser() - return nil case *tcpip.TCPLingerTimeoutOption: e.LockUser() *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) e.UnlockUser() - return nil case *tcpip.TCPDeferAcceptOption: e.LockUser() *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // checkV4MappedLocked determines the effective network protocol and converts @@ -2146,7 +2097,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { switch { case e.EndpointState().connected(): // Close for read. - if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { + if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. e.rcvListMu.Lock() e.rcvClosed = true @@ -2155,7 +2106,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. - if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { + if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { e.resetConnectionLocked(tcpip.ErrConnectionAborted) // Wake up worker to terminate loop. 
e.notifyProtocolGoroutine(notifyTickleWorker) @@ -2164,7 +2115,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } // Close for write. - if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { e.sndBufMu.Lock() if e.sndClosed { // Already closed. @@ -2187,12 +2138,23 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { return nil case e.EndpointState() == StateListen: - // Tell protocolListenLoop to stop. - if flags&tcpip.ShutdownRead != 0 { - e.notifyProtocolGoroutine(notifyClose) + if e.shutdownFlags&tcpip.ShutdownRead != 0 { + // Reset all connections from the accept queue and keep the + // worker running so that it can continue handling incoming + // segments by replying with RST. + // + // By not removing this endpoint from the demuxer mapping, we + // ensure that any other bind to the same port fails, as on Linux. + // TODO(gvisor.dev/issue/2468): We need to enable applications to + // start listening on this endpoint again similar to Linux. + e.rcvListMu.Lock() + e.rcvClosed = true + e.rcvListMu.Unlock() + e.closePendingAcceptableConnectionsLocked() + // Notify waiters that the endpoint is shutdown. + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) } return nil - default: return tcpip.ErrNotConnected } @@ -2296,8 +2258,11 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. - if e.EndpointState() != StateListen { + if rcvClosed || e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 808410c92..704d01c64 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -130,7 +130,7 @@ func (r *ForwarderRequest) Complete(sendReset bool) { // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment) + replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) } // Release all resources. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index dce9a1652..cfd9a4e8e 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -94,6 +94,63 @@ const ( ccCubic = "cubic" ) +// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The +// value is protected by a mutex so that we can increment only when it's +// guaranteed not to go above a threshold. +type synRcvdCounter struct { + sync.Mutex + value uint64 + pending sync.WaitGroup + threshold uint64 +} + +// inc tries to increment the global number of endpoints in SYN-RCVD state. It +// succeeds if the increment doesn't make the count go beyond the threshold, and +// fails otherwise. +func (s *synRcvdCounter) inc() bool { + s.Lock() + defer s.Unlock() + if s.value >= s.threshold { + return false + } + + s.pending.Add(1) + s.value++ + + return true +} + +// dec atomically decrements the global number of endpoints in SYN-RCVD +// state. It must only be called if a previous call to inc succeeded. +func (s *synRcvdCounter) dec() { + s.Lock() + defer s.Unlock() + s.value-- + s.pending.Done() +} + +// synCookiesInUse returns true if the synRcvdCount is greater than +// SynRcvdCountThreshold. 
+func (s *synRcvdCounter) synCookiesInUse() bool { + s.Lock() + defer s.Unlock() + return s.value >= s.threshold +} + +// SetThreshold sets synRcvdCounter.Threshold to ths new threshold. +func (s *synRcvdCounter) SetThreshold(threshold uint64) { + s.Lock() + defer s.Unlock() + s.threshold = threshold +} + +// Threshold returns the current value of synRcvdCounter.Threhsold. +func (s *synRcvdCounter) Threshold() uint64 { + s.Lock() + defer s.Unlock() + return s.threshold +} + type protocol struct { mu sync.RWMutex sackEnabled bool @@ -105,6 +162,8 @@ type protocol struct { moderateReceiveBuffer bool tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration + minRTO time.Duration + synRcvdCount synRcvdCounter dispatcher *dispatcher } @@ -164,12 +223,12 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo return true } - replyWithReset(s) + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) return true } // replyWithReset replies to the given segment with a reset segment. -func replyWithReset(s *segment) { +func replyWithReset(s *segment, tos, ttl uint8) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) @@ -193,8 +252,8 @@ func replyWithReset(s *segment) { } sendTCP(&s.route, tcpFields{ id: s.id, - ttl: s.route.DefaultTTL(), - tos: stack.DefaultTOS, + ttl: ttl, + tos: tos, flags: flags, seq: seq, ack: ack, @@ -272,6 +331,21 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPMinRTOOption: + if v < 0 { + v = tcpip.TCPMinRTOOption(MinRTO) + } + p.mu.Lock() + p.minRTO = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPSynRcvdCountThresholdOption: + p.mu.Lock() + p.synRcvdCount.SetThreshold(uint64(v)) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -334,6 +408,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *tcpip.TCPMinRTOOption: + p.mu.RLock() + *v = tcpip.TCPMinRTOOption(p.minRTO) + p.mu.RUnlock() + return nil + + case *tcpip.TCPSynRcvdCountThresholdOption: + p.mu.RLock() + *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + p.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -349,6 +435,12 @@ func (p *protocol) Wait() { p.dispatcher.wait() } +// SynRcvdCounter returns a reference to the synRcvdCount for this protocol +// instance. +func (p *protocol) SynRcvdCounter() *synRcvdCounter { + return &p.synRcvdCount +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ @@ -358,6 +450,8 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), + minRTO: MinRTO, } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 6b7bac37d..d8cfe3115 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "math" "sync/atomic" "time" @@ -149,6 +150,9 @@ type sender struct { rtt rtt rto time.Duration + // minRTO is the minimum permitted value for sender.rto. + minRTO time.Duration + // maxPayloadSize is the maximum size of the payload of a given segment. // It is initialized on demand. 
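Like the SYN-RCVD threshold, the minimum RTO is now a stack-wide TCP protocol option that newSender reads when a connection is created. A small sketch of overriding it on an existing stack (the function is illustrative; per the SetOption case above, a negative value falls back to the protocol default MinRTO):

package example

import (
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)

// setMinRTO overrides the stack-wide minimum retransmission timeout used by
// newly created TCP senders.
func setMinRTO(s *stack.Stack, d time.Duration) *tcpip.Error {
	return s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(d))
}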
maxPayloadSize int @@ -260,6 +264,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // etc. s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + // Get Stack wide minRTO. + var v tcpip.TCPMinRTOOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { + panic(fmt.Sprintf("unable to get minRTO from stack: %s", err)) + } + s.minRTO = time.Duration(v) + return s } @@ -394,8 +405,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < MinRTO { - s.rto = MinRTO + if s.rto < s.minRTO { + s.rto = s.minRTO } } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 782d7b42c..359a75e73 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/runsc/testutil" ) func TestFastRecovery(t *testing.T) { @@ -40,7 +41,7 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -86,16 +87,23 @@ func TestFastRecovery(t *testing.T) { // Receive the retransmitted packet. c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } + // Wait before checking metrics. + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) + } + return nil } - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Now send 7 mode duplicate acks. Each of these should cause a window @@ -117,12 +125,18 @@ func TestFastRecovery(t *testing.T) { // Receive the retransmit due to partial ack. c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + // Wait before checking metrics. 
+ metricPollFn = func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + } + return nil } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Receive the 10 extra packets that should have been released due to @@ -192,7 +206,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -234,7 +248,7 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -338,7 +352,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { @@ -447,7 +461,7 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -492,24 +506,33 @@ func TestRetransmit(t *testing.T) { rtxOffset := bytesRead - maxPayload*expected c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) - } + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) - } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + } + + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got EP stats 
SendErrors.Retransmits.Value = %v, want = %v", got, want) + return nil } - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) + // Poll when checking metrics. + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Acknowledge half of the pending data. diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index afea124ec..1dd63dd61 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -149,21 +149,22 @@ func TestSackPermittedAccept(t *testing.T) { {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + for _, tc := range testCases { t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } for _, sackEnabled := range []bool{false, true} { t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + + if tc.cookieEnabled { + // Set the SynRcvd threshold to + // zero to force a syn cookie + // based accept to happen. + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } setStackSACKPermitted(t, c, sackEnabled) rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) @@ -222,21 +223,23 @@ func TestSackDisabledAccept(t *testing.T) { {true, -1, 0xffff}, // When cookie is used window scaling is disabled. {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + for _, tc := range testCases { t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } for _, sackEnabled := range []bool{false, true} { t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + + if tc.cookieEnabled { + // Set the SynRcvd threshold to + // zero to force a syn cookie + // based accept to happen. 
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + setStackSACKPermitted(t, c, sackEnabled) rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) @@ -387,7 +390,7 @@ func TestSACKRecovery(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ce3df7478..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -284,7 +284,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err) } c.EP.Close() @@ -590,6 +590,10 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { ), ) + // Give the stack a few ms to transition the endpoint out of ESTABLISHED + // state. + time.Sleep(10 * time.Millisecond) + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } @@ -728,7 +732,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize tests := []struct { name string - setMSS uint16 + setMSS int expMSS uint16 }{ { @@ -756,15 +760,14 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { c.Create(-1) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -818,15 +821,14 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) { c.CreateV6Endpoint(true) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -1032,8 +1034,8 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.SeqNum(200))) } -// TestListenShutdown tests for the listening endpoint not processing -// any receive when it is on read shutdown. 
+// TestListenShutdown tests for the listening endpoint replying with RST +// on read shutdown. func TestListenShutdown(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -1044,7 +1046,7 @@ func TestListenShutdown(t *testing.T) { t.Fatal("Bind failed:", err) } - if err := c.EP.Listen(10 /* backlog */); err != nil { + if err := c.EP.Listen(1 /* backlog */); err != nil { t.Fatal("Listen failed:", err) } @@ -1052,9 +1054,6 @@ func TestListenShutdown(t *testing.T) { t.Fatal("Shutdown failed:", err) } - // Wait for the endpoint state to be propagated. - time.Sleep(10 * time.Millisecond) - c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -1063,7 +1062,49 @@ func TestListenShutdown(t *testing.T) { AckNum: 200, }) - c.CheckNoPacket("Packet received when listening socket was shutdown") + // Expect the listening endpoint to reset the connection. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) +} + +// TestListenCloseWhileConnect tests for the listening endpoint to +// drain the accept-queue when closed. This should reset all of the +// pending connections that are waiting to be accepted. +func TestListenCloseWhileConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(1 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventIn) + defer c.WQ.EventUnregister(&waitEntry) + + executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + // Wait for the new endpoint created because of handshake to be delivered + // to the listening endpoint's accept queue. + <-notifyCh + + // Close the listening endpoint. + c.EP.Close() + + // Expect the listening endpoint to reset the connection. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) } func TestTOSV4(t *testing.T) { @@ -1077,17 +1118,17 @@ func TestTOSV4(t *testing.T) { c.EP = ep const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) } - var v tcpip.IPv4TOSOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Errorf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - t.Errorf("got GetSockOpt(...) 
= %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) } testV4Connect(t, c, checker.TOS(tos, 0)) @@ -1125,17 +1166,17 @@ func TestTrafficClassV6(t *testing.T) { c.CreateV6Endpoint(false) const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { + t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) } - var v tcpip.IPv6TrafficClassOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) } // Test the connection request. @@ -1711,7 +1752,7 @@ func TestNoWindowShrinking(t *testing.T) { c.CreateConnected(789, 30000, 10) if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err) } we, ch := waiter.NewChannelEntry(nil) @@ -1984,7 +2025,7 @@ func TestScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2057,7 +2098,7 @@ func TestNonScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2221,10 +2262,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) + ep.SetSockOptBool(tcpip.CorkOption, true) }, func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) + ep.SetSockOptBool(tcpip.CorkOption, false) }, }, } @@ -2316,7 +2357,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -2364,7 +2405,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -2397,7 +2438,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptInt(tcpip.DelayOption, 0) + c.EP.SetSockOptBool(tcpip.DelayOption, false) // Check that data is received. 
second := c.GetPacket() @@ -2434,8 +2475,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, } for _, test := range tests { @@ -2576,12 +2617,12 @@ func TestSetTTL(t *testing.T) { t.Fatalf("NewEndpoint failed: %v", err) } - if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("Unexpected return value from Connect: %s", err) } // Receive SYN packet. @@ -2621,7 +2662,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2667,26 +2708,24 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } // Create EP and start listening. wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake. @@ -2704,7 +2743,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -2765,7 +2804,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { const rcvBufferSize = 0x20000 const wndScale = 2 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } // Start connection attempt. @@ -3882,26 +3921,26 @@ func TestMinMaxBufferSizes(t *testing.T) { // Set values below the min. 
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) @@ -4147,11 +4186,11 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { case "ipv4": case "ipv6": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) } case "dual": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) } default: t.Fatalf("unknown network: '%s'", network) @@ -4474,11 +4513,11 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const keepAliveInterval = 10 * time.Millisecond - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + const keepAliveInterval = 3 * time.Second + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -4569,7 +4608,7 @@ func TestKeepalive(t *testing.T) { // Sleep for a litte over the KeepAlive interval to make sure // the timer has time to fire after the last ACK and close the // close the socket. - time.Sleep(keepAliveInterval + 5*time.Millisecond) + time.Sleep(keepAliveInterval + keepAliveInterval/2) // The connection should be terminated after 5 unacked keepalives. // Send an ACK to trigger a RST from the stack as the endpoint should @@ -5104,25 +5143,23 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } func TestListenBacklogFullSynCookieInUse(t *testing.T) { - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 1 - c := context.New(t, defaultMTU) defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + } + // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Bind to wildcard. 
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -5130,7 +5167,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { listenBacklog := 1 portOffset := uint16(0) if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } executeHandshake(t, c, context.TestPort+portOffset, false) @@ -5609,7 +5646,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { return } if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd) + t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) } }, )) @@ -5770,14 +5807,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { func TestDelayEnabled(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - checkDelayOption(t, c, false, 0) // Delay is disabled by default. + checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { delayEnabled tcp.DelayEnabled - wantDelayOption int + wantDelayOption bool }{ - {delayEnabled: false, wantDelayOption: 0}, - {delayEnabled: true, wantDelayOption: 1}, + {delayEnabled: false, wantDelayOption: false}, + {delayEnabled: true, wantDelayOption: true}, } { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5788,7 +5825,7 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() var gotDelayEnabled tcp.DelayEnabled @@ -5803,12 +5840,12 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del if err != nil { t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) } - gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { - t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) } if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) } } @@ -6617,14 +6654,17 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - const keepAliveInterval = 10 * time.Millisecond - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + const keepAliveInterval = 3 * time.Second + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) - - // Set userTimeout to be the duration for 3 keepalive probes. - userTimeout := 30 * time.Millisecond + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + + // Set userTimeout to be the duration to be 1 keepalive + // probes. Which means that after the first probe is sent + // the second one should cause the connection to be + // closed due to userTimeout being hit. 
+ userTimeout := 1 * keepAliveInterval c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) // Check that the connection is still alive. @@ -6632,28 +6672,23 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } - // Now receive 2 keepalives, but don't ACK them. The connection should - // be reset when the 3rd one should be sent due to userTimeout being - // 30ms and each keepalive probe should be sent 10ms apart as set above after - // the connection has been idle for 10ms. - for i := 0; i < 2; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } + // Now receive 1 keepalives, but don't ACK it. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) // Sleep for a litte over the KeepAlive interval to make sure // the timer has time to fire after the last ACK and close the // close the socket. - time.Sleep(keepAliveInterval + 5*time.Millisecond) + time.Sleep(keepAliveInterval + keepAliveInterval/2) - // The connection should be terminated after 30ms. + // The connection should be closed with a timeout. // Send an ACK to trigger a RST from the stack as the endpoint should // be dead. c.SendPacket(nil, &context.Headers{ diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index a641e953d..8edbff964 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -127,16 +127,14 @@ func TestTimeStampDisabledConnect(t *testing.T) { } func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + c := context.New(t, defaultMTU) + defer c.Cleanup() if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } } - c := context.New(t, defaultMTU) - defer c.Cleanup() t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) tsVal := rand.Uint32() @@ -148,7 +146,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Unexpected error from Write: %s", err) } // Check that data is received and that the timestamp option TSEcr field @@ -190,17 +188,15 @@ func TestTimeStampEnabledAccept(t *testing.T) { } func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - c := context.New(t, defaultMTU) defer c.Cleanup() + if cookieEnabled { + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + t.Logf("Test w/ 
CookieEnabled = %v", cookieEnabled) c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) @@ -211,7 +207,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Unexpected error from Write: %s", err) } // Check that data is received and that the timestamp option is disabled diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d4f6bc635..7b1d72cf4 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -152,6 +152,13 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("SetTransportProtocolOption failed: %v", err) } + // Increase minimum RTO in tests to avoid test flakes due to early + // retransmit in case the test executors are overloaded and cause timers + // to fire earlier than expected. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { + t.Fatalf("failed to set stack-wide minRTO: %s", err) + } + // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. ep := channel.New(1000, mtu, "") @@ -217,7 +224,8 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), wait) + ctx, cancel := context.WithTimeout(context.Background(), wait) + defer cancel() if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) } @@ -235,7 +243,8 @@ func (c *Context) CheckNoPacket(errMsg string) { func (c *Context) GetPacket() []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") @@ -415,6 +424,8 @@ func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlock // verifies that the packet packet payload of packet matches the slice // of data indicated by offset & size. func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { + c.t.Helper() + c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) } @@ -423,6 +434,8 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { // data indicated by offset & size and skips optlen bytes in addition to the IP // TCP headers when comparing the data. func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { + c.t.Helper() + b := c.GetPacket() checker.IPv4(c.t, b, checker.PayloadLen(size+header.TCPMinimumSize+optlen), @@ -445,6 +458,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op // data indicated by offset & size. It returns true if a packet was received and // processed. 
func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { + c.t.Helper() + b := c.GetPacketNonBlocking() if b == nil { return false @@ -486,7 +501,8 @@ func (c *Context) CreateV6Endpoint(v6only bool) { func (c *Context) GetV6Packet() []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") @@ -567,6 +583,8 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf // // PreCondition: c.EP must already be created. func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { + c.t.Helper() + // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 120d3baa3..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -501,11 +501,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v + e.mu.Unlock() + + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v e.mu.Unlock() - return nil case tcpip.ReceiveTClassOption: // We only support this option on v6 endpoints. @@ -516,7 +525,18 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Lock() e.receiveTClass = v e.mu.Unlock() - return nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + + case tcpip.ReuseAddressOption: + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v + e.mu.Unlock() case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. @@ -533,13 +553,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } e.v6only = v - return nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - return nil } return nil @@ -547,22 +560,38 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return nil -} + switch opt { + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) e.mu.Unlock() - case tcpip.MulticastTTLOption: + case tcpip.IPv4TOSOption: e.mu.Lock() - e.multicastTTL = uint8(v) + e.sendTOS = uint8(v) e.mu.Unlock() + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.ReceiveBufferSizeOption: + case tcpip.SendBufferSizeOption: + + } + + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. 
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.MulticastInterfaceOption: e.mu.Lock() defer e.mu.Unlock() @@ -686,16 +715,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = bool(v) - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -704,26 +723,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() - return nil - - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v != 0 - e.mu.Unlock() - - return nil - - case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil - - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil } return nil } @@ -731,6 +730,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + return v, nil + + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -748,6 +762,22 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil + + case tcpip.ReuseAddressOption: + return false, nil + + case tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -760,19 +790,32 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil + default: + return false, tcpip.ErrUnknownProtocolOption } - - return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -794,29 +837,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := e.rcvBufSizeMax e.rcvMu.Unlock() return v, nil - } - return -1, tcpip.ErrUnknownProtocolOption + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - return nil - - case *tcpip.TTLOption: - e.mu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() - return nil - - case *tcpip.MulticastTTLOption: - e.mu.Lock() - *o = tcpip.MulticastTTLOption(e.multicastTTL) - e.mu.Unlock() - return nil - case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -824,67 +860,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.multicastAddr, } e.mu.Unlock() - return nil - - case *tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - - *o = tcpip.MulticastLoopOption(v) - return nil - - case *tcpip.ReuseAddressOption: - *o = 0 - return nil - - case *tcpip.ReusePortOption: - e.mu.RLock() - v := e.reusePort - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.BindToDeviceOption: e.mu.RLock() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.IPv4TOSOption: - e.mu.RLock() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0905726c1..8acaa607a 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -343,11 +343,11 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOptBool failed: %s", err) } } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) + if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) } } } @@ -358,7 +358,8 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") @@ -607,7 +608,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr) } // Check the payload. 
@@ -1271,8 +1272,8 @@ func TestTTL(t *testing.T) { c.createEndpointForFlow(flow) const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { + c.t.Fatalf("SetSockOptInt failed: %s", err) } var wantTTL uint8 @@ -1311,8 +1312,8 @@ func TestSetTTL(t *testing.T) { c.createEndpointForFlow(flow) - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } var p stack.NetworkProtocol @@ -1346,25 +1347,26 @@ func TestSetTOS(t *testing.T) { c.createEndpointForFlow(flow) const tos = testTOS - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) + if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tos { + c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1381,25 +1383,26 @@ func TestSetTClass(t *testing.T) { c.createEndpointForFlow(flow) const tClass = testTOS - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } // Test for expected default value. 
if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) + if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { + c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tClass); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tClass { + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) } // The header getter for TClass is called TOS, so use that checker. @@ -1430,7 +1433,7 @@ func TestReceiveTosTClass(t *testing.T) { // Verify that setting and reading the option works. v, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } // Test for expected default value. if v != false { @@ -1444,7 +1447,7 @@ func TestReceiveTosTClass(t *testing.T) { got, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } if got != want { @@ -1563,7 +1566,8 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) } @@ -1571,7 +1575,8 @@ func TestV4UnknownDestination(t *testing.T) { } // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { t.Fatalf("packet wasn't written out") @@ -1639,7 +1644,8 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) } @@ -1647,7 +1653,8 @@ func TestV6UnknownDestination(t *testing.T) { } // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { t.Fatalf("packet wasn't written out") diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index d2f4403b0..cd6a0ea6b 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -29,9 +29,6 @@ import ( ) // IO provides access to the contents of a virtual memory space. -// -// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any -// meaningful data. type IO interface { // CopyOut copies len(src) bytes from src to the memory mapped at addr. It // returns the number of bytes copied. 
If the number of bytes copied is < diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 26f68fe3d..5451f1eba 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -21,6 +21,7 @@ go_library( "network.go", "strace.go", "user.go", + "vfs.go", ], visibility = [ "//runsc:__subpackages__", @@ -33,6 +34,7 @@ go_library( "//pkg/control/server", "//pkg/cpuid", "//pkg/eventchannel", + "//pkg/fspath", "//pkg/log", "//pkg/memutil", "//pkg/rand", @@ -40,6 +42,7 @@ go_library( "//pkg/sentry/arch", "//pkg/sentry/arch:registers_go_proto", "//pkg/sentry/control", + "//pkg/sentry/devices/memdev", "//pkg/sentry/fs", "//pkg/sentry/fs/dev", "//pkg/sentry/fs/gofer", @@ -49,6 +52,12 @@ go_library( "//pkg/sentry/fs/sys", "//pkg/sentry/fs/tmpfs", "//pkg/sentry/fs/tty", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/gofer", + "//pkg/sentry/fsimpl/host", + "//pkg/sentry/fsimpl/proc", + "//pkg/sentry/fsimpl/sys", + "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel:uncaught_signal_go_proto", @@ -71,6 +80,7 @@ go_library( "//pkg/sentry/time", "//pkg/sentry/unimpl:unimplemented_syscall_go_proto", "//pkg/sentry/usage", + "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/sync", "//pkg/syserror", @@ -114,10 +124,12 @@ go_test( "//pkg/p9", "//pkg/sentry/contexttest", "//pkg/sentry/fs", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sync", "//pkg/unet", "//runsc/fsgofer", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 8995d678e..b7cfb35bf 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -65,7 +65,7 @@ func newCompatEmitter(logFD int) (*compatEmitter, error) { if logFD > 0 { f := os.NewFile(uintptr(logFD), "user log file") - target := &log.MultiEmitter{c.sink, &log.K8sJSONEmitter{log.Writer{Next: f}}} + target := &log.MultiEmitter{c.sink, log.K8sJSONEmitter{&log.Writer{Next: f}}} c.sink = &log.BasicLogger{Level: log.Info, Emitter: target} } return c, nil diff --git a/runsc/boot/config.go b/runsc/boot/config.go index 7ea5bfade..715a19112 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -305,5 +305,10 @@ func (c *Config) ToFlags() []string { if len(c.TestOnlyTestNameEnv) != 0 { f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv) } + + if c.VFS2 { + f = append(f, "--vfs2=true") + } + return f } diff --git a/runsc/boot/fds.go b/runsc/boot/fds.go index 5314b0f2a..7e49f6f9f 100644 --- a/runsc/boot/fds.go +++ b/runsc/boot/fds.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" + vfshost "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" "gvisor.dev/gvisor/pkg/sentry/kernel" ) @@ -31,6 +32,10 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F return nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs)) } + if kernel.VFS2Enabled { + return createFDTableVFS2(ctx, console, stdioFDs) + } + k := kernel.KernelFromContext(ctx) fdTable := k.NewFDTable() defer fdTable.DecRef() @@ -78,3 +83,31 @@ func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.F fdTable.IncRef() return fdTable, nil } + +func createFDTableVFS2(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, error) { + k := kernel.KernelFromContext(ctx) + fdTable := k.NewFDTable() + defer fdTable.DecRef() + + hostMount, 
err := vfshost.NewMount(k.VFS()) + if err != nil { + return nil, fmt.Errorf("creating host mount: %w", err) + } + + for appFD, hostFD := range stdioFDs { + // TODO(gvisor.dev/issue/1482): Add TTY support. + appFile, err := vfshost.ImportFD(hostMount, hostFD, false) + if err != nil { + return nil, err + } + + if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil { + appFile.DecRef() + return nil, err + } + appFile.DecRef() + } + + fdTable.IncRef() + return fdTable, nil +} diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 06b9f888a..1828d116a 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -44,7 +44,7 @@ var allowedSyscalls = seccomp.SyscallRules{ { seccomp.AllowAny{}, seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.AllowValue(syscall.O_CLOEXEC), }, }, syscall.SYS_EPOLL_CREATE1: {}, diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 0f62842ea..98cce60af 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -278,6 +278,9 @@ func subtargets(root string, mnts []specs.Mount) []string { } func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { + if conf.VFS2 { + return setupContainerVFS2(ctx, conf, mntr, procArgs) + } mns, err := mntr.setupFS(conf, procArgs) if err != nil { return err @@ -573,6 +576,9 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin // should be mounted (e.g. a volume shared between containers). It must be // called for the root container only. func (c *containerMounter) processHints(conf *Config) error { + if conf.VFS2 { + return nil + } ctx := c.k.SupervisorContext() for _, hint := range c.hints.mounts { // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a @@ -781,9 +787,6 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly default: - // TODO(nlacasse): Support all the mount types and make this a fatal error. - // Most applications will "just work" without them, so this is a warning - // for now. log.Warningf("ignoring unknown filesystem type %q", m.Type) } return fsName, opts, useOverlay, nil @@ -824,7 +827,20 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns inode, err := filesystem.Mount(ctx, mountDevice(m), mf, strings.Join(opts, ","), nil) if err != nil { - return fmt.Errorf("creating mount with source %q: %v", m.Source, err) + err := fmt.Errorf("creating mount with source %q: %v", m.Source, err) + // Check to see if this is a common error due to a Linux bug. + // This error is generated here in order to cause it to be + // printed to the user using Docker via 'runsc create' etc. rather + // than simply printed to the logs for the 'runsc boot' command. + // + // We check the error message string rather than type because the + // actual error types (syscall.EIO, syscall.EPIPE) are lost by file system + // implementation (e.g. p9). + // TODO(gvisor.dev/issue/1765): Remove message when bug is resolved. 
+ if strings.Contains(err.Error(), syscall.EIO.Error()) || strings.Contains(err.Error(), syscall.EPIPE.Error()) { + return fmt.Errorf("%v: %s", err, specutils.FaqErrorMsg("memlock", "you may be encountering a Linux kernel bug")) + } + return err } // If there are submounts, we need to overlay the mount on top of a ramfs diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index e7ca98134..cf1f47bc7 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -26,7 +26,6 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" @@ -73,6 +72,8 @@ import ( _ "gvisor.dev/gvisor/pkg/sentry/socket/unix" ) +var syscallTable *kernel.SyscallTable + // Loader keeps state needed to start the kernel and run the container.. type Loader struct { // k is the kernel. @@ -156,13 +157,17 @@ type Args struct { Spec *specs.Spec // Conf is the system configuration. Conf *Config - // ControllerFD is the FD to the URPC controller. + // ControllerFD is the FD to the URPC controller. The Loader takes ownership + // of this FD and may close it at any time. ControllerFD int - // Device is an optional argument that is passed to the platform. + // Device is an optional argument that is passed to the platform. The Loader + // takes ownership of this file and may close it at any time. Device *os.File - // GoferFDs is an array of FDs used to connect with the Gofer. + // GoferFDs is an array of FDs used to connect with the Gofer. The Loader + // takes ownership of these FDs and may close them at any time. GoferFDs []int - // StdioFDs is the stdio for the application. + // StdioFDs is the stdio for the application. The Loader takes ownership of + // these FDs and may close them at any time. StdioFDs []int // Console is set to true if using TTY. Console bool @@ -175,6 +180,9 @@ type Args struct { UserLogFD int } +// make sure stdioFDs are always the same on initial start and on restore +const startingStdioFD = 64 + // New initializes a new kernel loader configured by spec. // New also handles setting up a kernel for restoring a container. func New(args Args) (*Loader, error) { @@ -188,13 +196,14 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("setting up memory usage: %v", err) } - if args.Conf.VFS2 { - st, ok := kernel.LookupSyscallTable(abi.Linux, arch.Host) - if ok { - vfs2.Override(st.Table) - } + // Patch the syscall table. + kernel.VFS2Enabled = args.Conf.VFS2 + if kernel.VFS2Enabled { + vfs2.Override(syscallTable.Table) } + kernel.RegisterSyscallTable(syscallTable) + // Create kernel and platform. p, err := createPlatform(args.Conf, args.Device) if err != nil { @@ -319,6 +328,24 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("creating pod mount hints: %v", err) } + // Make host FDs stable between invocations. Host FDs must map to the exact + // same number when the sandbox is restored. Otherwise the wrong FD will be + // used. 
+ var stdioFDs []int + newfd := startingStdioFD + for _, fd := range args.StdioFDs { + err := syscall.Dup3(fd, newfd, syscall.O_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err) + } + stdioFDs = append(stdioFDs, newfd) + err = syscall.Close(fd) + if err != nil { + return nil, fmt.Errorf("close original stdioFDs failed: %v", err) + } + newfd++ + } + eid := execID{cid: args.ID} l := &Loader{ k: k, @@ -327,7 +354,7 @@ func New(args Args) (*Loader, error) { watchdog: dog, spec: args.Spec, goferFDs: args.GoferFDs, - stdioFDs: args.StdioFDs, + stdioFDs: stdioFDs, rootProcArgs: procArgs, sandboxID: args.ID, processes: map[execID]*execProcess{eid: {}}, @@ -367,11 +394,16 @@ func newProcess(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel. return kernel.CreateProcessArgs{}, fmt.Errorf("creating limits: %v", err) } + wd := spec.Process.Cwd + if wd == "" { + wd = "/" + } + // Create the process arguments. procArgs := kernel.CreateProcessArgs{ Argv: spec.Process.Args, Envv: spec.Process.Env, - WorkingDirectory: spec.Process.Cwd, // Defaults to '/' if empty. + WorkingDirectory: wd, Credentials: creds, Umask: 0022, Limits: ls, @@ -516,7 +548,15 @@ func (l *Loader) run() error { } // Add the HOME enviroment variable if it is not already set. - envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) + var envv []string + if kernel.VFS2Enabled { + envv, err = maybeAddExecUserHomeVFS2(ctx, l.rootProcArgs.MountNamespaceVFS2, + l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) + + } else { + envv, err = maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, + l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) + } if err != nil { return err } @@ -569,6 +609,16 @@ func (l *Loader) run() error { } }) + // l.stdioFDs are derived from dup() in boot.New() and they are now dup()ed again + // either in createFDTable() during initial start or in descriptor.initAfterLoad() + // during restore, we can release l.stdioFDs now. + for _, fd := range l.stdioFDs { + err := syscall.Close(fd) + if err != nil { + return fmt.Errorf("close dup()ed stdioFDs: %v", err) + } + } + log.Infof("Process should have started...") l.watchdog.Start() return l.k.Start() diff --git a/runsc/boot/loader_amd64.go b/runsc/boot/loader_amd64.go index b9669f2ac..78df86611 100644 --- a/runsc/boot/loader_amd64.go +++ b/runsc/boot/loader_amd64.go @@ -17,11 +17,10 @@ package boot import ( - "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" ) func init() { - // Register the global syscall table. - kernel.RegisterSyscallTable(linux.AMD64) + // Set the global syscall table. + syscallTable = linux.AMD64 } diff --git a/runsc/boot/loader_arm64.go b/runsc/boot/loader_arm64.go index cf64d28c8..250785010 100644 --- a/runsc/boot/loader_arm64.go +++ b/runsc/boot/loader_arm64.go @@ -17,11 +17,10 @@ package boot import ( - "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" ) func init() { - // Register the global syscall table. - kernel.RegisterSyscallTable(linux.ARM64) + // Set the global syscall table. 
+ syscallTable = linux.ARM64 } diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 44aa63196..e7c71734f 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -24,11 +24,13 @@ import ( "time" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/control/server" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/fsgofer" @@ -65,6 +67,11 @@ func testSpec() *specs.Spec { } } +func resetSyscallTable() { + kernel.VFS2Enabled = false + kernel.FlushSyscallTablesTestOnly() +} + // startGofer starts a new gofer routine serving 'root' path. It returns the // sandbox side of the connection, and a function that when called will stop the // gofer. @@ -100,7 +107,7 @@ func startGofer(root string) (int, func(), error) { return sandboxEnd, cleanup, nil } -func createLoader() (*Loader, func(), error) { +func createLoader(vfsEnabled bool) (*Loader, func(), error) { fd, err := server.CreateSocket(ControlSocketAddr(fmt.Sprintf("%010d", rand.Int())[:10])) if err != nil { return nil, nil, err @@ -108,12 +115,23 @@ func createLoader() (*Loader, func(), error) { conf := testConfig() spec := testSpec() + conf.VFS2 = vfsEnabled + sandEnd, cleanup, err := startGofer(spec.Root.Path) if err != nil { return nil, nil, err } - stdio := []int{int(os.Stdin.Fd()), int(os.Stdout.Fd()), int(os.Stderr.Fd())} + // Loader takes ownership of stdio. + var stdio []int + for _, f := range []*os.File{os.Stdin, os.Stdout, os.Stderr} { + newFd, err := unix.Dup(int(f.Fd())) + if err != nil { + return nil, nil, err + } + stdio = append(stdio, newFd) + } + args := Args{ ID: "foo", Spec: spec, @@ -132,10 +150,22 @@ func createLoader() (*Loader, func(), error) { // TestRun runs a simple application in a sandbox and checks that it succeeds. func TestRun(t *testing.T) { - l, cleanup, err := createLoader() + defer resetSyscallTable() + doRun(t, false) +} + +// TestRunVFS2 runs TestRun in VFSv2. +func TestRunVFS2(t *testing.T) { + defer resetSyscallTable() + doRun(t, true) +} + +func doRun(t *testing.T, vfsEnabled bool) { + l, cleanup, err := createLoader(vfsEnabled) if err != nil { t.Fatalf("error creating loader: %v", err) } + defer l.Destroy() defer cleanup() @@ -169,7 +199,18 @@ func TestRun(t *testing.T) { // TestStartSignal tests that the controller Start message will cause // WaitForStartSignal to return. func TestStartSignal(t *testing.T) { - l, cleanup, err := createLoader() + defer resetSyscallTable() + doStartSignal(t, false) +} + +// TestStartSignalVFS2 does TestStartSignal with VFS2. 
+func TestStartSignalVFS2(t *testing.T) { + defer resetSyscallTable() + doStartSignal(t, true) +} + +func doStartSignal(t *testing.T, vfsEnabled bool) { + l, cleanup, err := createLoader(vfsEnabled) if err != nil { t.Fatalf("error creating loader: %v", err) } diff --git a/runsc/boot/user.go b/runsc/boot/user.go index f0aa52135..332e4fce5 100644 --- a/runsc/boot/user.go +++ b/runsc/boot/user.go @@ -23,8 +23,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/usermem" ) @@ -84,6 +86,48 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K File: f, } + return findHomeInPasswd(uint32(uid), r, defaultHome) +} + +type fileReaderVFS2 struct { + ctx context.Context + fd *vfs.FileDescription +} + +func (r *fileReaderVFS2) Read(buf []byte) (int, error) { + n, err := r.fd.Read(r.ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + return int(n), err +} + +func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.KUID) (string, error) { + const defaultHome = "/" + + root := mns.Root() + defer root.DecRef() + + creds := auth.CredentialsFromContext(ctx) + + target := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse("/etc/passwd"), + } + + opts := &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + } + + fd, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, target, opts) + if err != nil { + return defaultHome, nil + } + defer fd.DecRef() + + r := &fileReaderVFS2{ + ctx: ctx, + fd: fd, + } + homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome) if err != nil { return "", err @@ -111,6 +155,26 @@ func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth. if err != nil { return nil, fmt.Errorf("error reading exec user: %v", err) } + + return append(envv, "HOME="+homeDir), nil +} + +func maybeAddExecUserHomeVFS2(ctx context.Context, vmns *vfs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) { + // Check if the envv already contains HOME. + for _, env := range envv { + if strings.HasPrefix(env, "HOME=") { + // We have it. Return the original slice unmodified. + return envv, nil + } + } + + // Read /etc/passwd for the user's HOME directory and set the HOME + // environment variable as required by POSIX if it is not overridden by + // the user. + homeDir, err := getExecUserHomeVFS2(ctx, vmns, uid) + if err != nil { + return nil, fmt.Errorf("error reading exec user: %v", err) + } return append(envv, "HOME="+homeDir), nil } diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go new file mode 100644 index 000000000..82083c57d --- /dev/null +++ b/runsc/boot/vfs.go @@ -0,0 +1,310 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package boot + +import ( + "fmt" + "path" + "strconv" + "strings" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/devices/memdev" + "gvisor.dev/gvisor/pkg/sentry/fs" + devtmpfsimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + goferimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer" + procimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc" + sysimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys" + tmpfsimpl "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/syserror" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) error { + + vfsObj.MustRegisterFilesystemType(rootFsName, &goferimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserList: true, + }) + + vfsObj.MustRegisterFilesystemType(bind, &goferimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserList: true, + }) + + vfsObj.MustRegisterFilesystemType(devpts, &devtmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) + + vfsObj.MustRegisterFilesystemType(devtmpfs, &devtmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) + vfsObj.MustRegisterFilesystemType(proc, &procimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) + vfsObj.MustRegisterFilesystemType(sysfs, &sysimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) + vfsObj.MustRegisterFilesystemType(tmpfs, &tmpfsimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) + vfsObj.MustRegisterFilesystemType(nonefs, &sysimpl.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) + + // Setup files in devtmpfs. + if err := memdev.Register(vfsObj); err != nil { + return fmt.Errorf("registering memdev: %w", err) + } + a, err := devtmpfsimpl.NewAccessor(ctx, vfsObj, creds, devtmpfsimpl.Name) + if err != nil { + return fmt.Errorf("creating devtmpfs accessor: %w", err) + } + defer a.Release() + + if err := a.UserspaceInit(ctx); err != nil { + return fmt.Errorf("initializing userspace: %w", err) + } + if err := memdev.CreateDevtmpfsFiles(ctx, a); err != nil { + return fmt.Errorf("creating devtmpfs files: %w", err) + } + return nil +} + +func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { + if err := mntr.k.VFS().Init(); err != nil { + return fmt.Errorf("failed to initialize VFS: %w", err) + } + mns, err := mntr.setupVFS2(ctx, conf, procArgs) + if err != nil { + return fmt.Errorf("failed to setupFS: %w", err) + } + procArgs.MountNamespaceVFS2 = mns + return setExecutablePathVFS2(ctx, procArgs) +} + +func setExecutablePathVFS2(ctx context.Context, procArgs *kernel.CreateProcessArgs) error { + + exe := procArgs.Argv[0] + + // Absolute paths can be used directly. + if path.IsAbs(exe) { + procArgs.Filename = exe + return nil + } + + // Paths with '/' in them should be joined to the working directory, or + // to the root if working directory is not set. 
+ if strings.IndexByte(exe, '/') > 0 { + + if !path.IsAbs(procArgs.WorkingDirectory) { + return fmt.Errorf("working directory %q must be absolute", procArgs.WorkingDirectory) + } + + procArgs.Filename = path.Join(procArgs.WorkingDirectory, exe) + return nil + } + + // Paths with a '/' are relative to the CWD. + if strings.IndexByte(exe, '/') > 0 { + procArgs.Filename = path.Join(procArgs.WorkingDirectory, exe) + return nil + } + + // Otherwise, We must lookup the name in the paths, starting from the + // root directory. + root := procArgs.MountNamespaceVFS2.Root() + defer root.DecRef() + + paths := fs.GetPath(procArgs.Envv) + creds := procArgs.Credentials + + for _, p := range paths { + + binPath := path.Join(p, exe) + + pop := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(binPath), + FollowFinalSymlink: true, + } + + opts := &vfs.OpenOptions{ + FileExec: true, + Flags: linux.O_RDONLY, + } + + dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts) + if err == syserror.ENOENT || err == syserror.EACCES { + // Didn't find it here. + continue + } + if err != nil { + return err + } + dentry.DecRef() + + procArgs.Filename = binPath + return nil + } + + return fmt.Errorf("executable %q not found in $PATH=%q", exe, strings.Join(paths, ":")) +} + +func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) { + log.Infof("Configuring container's file system with VFS2") + + // Create context with root credentials to mount the filesystem (the current + // user may not be privileged enough). + rootProcArgs := *procArgs + rootProcArgs.WorkingDirectory = "/" + rootProcArgs.Credentials = auth.NewRootCredentials(procArgs.Credentials.UserNamespace) + rootProcArgs.Umask = 0022 + rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals + rootCtx := procArgs.NewContext(c.k) + + creds := procArgs.Credentials + if err := registerFilesystems(rootCtx, c.k.VFS(), creds); err != nil { + return nil, fmt.Errorf("register filesystems: %w", err) + } + + fd := c.fds.remove() + + opts := strings.Join(p9MountOptionsVFS2(fd, conf.FileAccess), ",") + + log.Infof("Mounting root over 9P, ioFD: %d", fd) + mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", rootFsName, &vfs.GetFilesystemOptions{Data: opts}) + if err != nil { + return nil, fmt.Errorf("setting up mountnamespace: %w", err) + } + + rootProcArgs.MountNamespaceVFS2 = mns + + // Mount submounts. + if err := c.mountSubmountsVFS2(rootCtx, conf, mns, creds); err != nil { + return nil, fmt.Errorf("mounting submounts vfs2: %w", err) + } + + return mns, nil +} + +func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials) error { + + for _, submount := range c.mounts { + log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options) + if err := c.mountSubmountVFS2(ctx, conf, mns, creds, &submount); err != nil { + return err + } + } + + // TODO(gvisor.dev/issue/1487): implement mountTmp from fs.go. + + return c.checkDispenser() +} + +// TODO(gvisor.dev/issue/1487): Implement submount options similar to the VFS1 version. 
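setExecutablePathVFS2 resolves a bare executable name by joining it to each $PATH entry and probing with OpenAt against the sandbox mount namespace. A host-filesystem analogue of the same walk using only the standard library; lookPathAt is an illustrative name, and the execute-bit check is an assumption (the VFS2 probe relies on OpenOptions.FileExec instead).

package main

import (
	"fmt"
	"os"
	"path"
	"strings"
)

// lookPathAt tries each directory in pathEnv and returns the first joined path
// that names an executable regular file, mirroring the $PATH walk above.
func lookPathAt(exe, pathEnv string) (string, error) {
	for _, dir := range strings.Split(pathEnv, ":") {
		candidate := path.Join(dir, exe)
		info, err := os.Stat(candidate)
		if err != nil {
			continue // Treat lookup failures like ENOENT and keep searching.
		}
		if info.Mode().IsRegular() && info.Mode().Perm()&0111 != 0 {
			return candidate, nil
		}
	}
	return "", fmt.Errorf("executable %q not found in $PATH=%q", exe, pathEnv)
}

func main() {
	p, err := lookPathAt("ls", "/usr/local/bin:/usr/bin:/bin")
	fmt.Println(p, err)
}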
+func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *specs.Mount) error { + root := mns.Root() + defer root.DecRef() + target := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(submount.Destination), + } + + _, options, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, *submount) + if err != nil { + return fmt.Errorf("mountOptions failed: %w", err) + } + + opts := &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: strings.Join(options, ","), + }, + InternalMount: true, + } + + // All writes go to upper, be paranoid and make lower readonly. + opts.ReadOnly = useOverlay + + if err := c.k.VFS().MountAt(ctx, creds, "", target, submount.Type, opts); err != nil { + return fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts) + } + log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts) + return nil +} + +// getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values +// used for mounts. +func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m specs.Mount) (string, []string, bool, error) { + var ( + fsName string + opts []string + useOverlay bool + ) + + switch m.Type { + case devpts, devtmpfs, proc, sysfs: + fsName = m.Type + case nonefs: + fsName = sysfs + case tmpfs: + fsName = m.Type + + var err error + opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedOptions...) + if err != nil { + return "", nil, false, err + } + + case bind: + fd := c.fds.remove() + fsName = "9p" + opts = p9MountOptionsVFS2(fd, c.getMountAccessType(m)) + // If configured, add overlay to all writable mounts. + useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly + + default: + log.Warningf("ignoring unknown filesystem type %q", m.Type) + } + return fsName, opts, useOverlay, nil +} + +// p9MountOptions creates a slice of options for a p9 mount. +// TODO(gvisor.dev/issue/1200): Remove this version in favor of the one in +// fs.go when privateunixsocket lands. +func p9MountOptionsVFS2(fd int, fa FileAccessType) []string { + opts := []string{ + "trans=fd", + "rfdno=" + strconv.Itoa(fd), + "wfdno=" + strconv.Itoa(fd), + } + if fa == FileAccessShared { + opts = append(opts, "cache=remote_revalidating") + } + return opts +} diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go index 0c27f7313..9360d7442 100644 --- a/runsc/cmd/capability_test.go +++ b/runsc/cmd/capability_test.go @@ -85,7 +85,7 @@ func TestCapabilities(t *testing.T) { Inheritable: caps, } - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) // Use --network=host to make sandbox use spec's capabilities. conf.Network = boot.NetworkHost diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 02e5af3d3..28f0d54b9 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -272,9 +272,8 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error { root := spec.Root.Path if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { - // FIXME: runsc can't be re-executed without - // /proc, so we create a tmpfs mount, mount ./proc and ./root - // there, then move this mount to the root and after + // runsc can't be re-executed without /proc, so we create a tmpfs mount, + // mount ./proc and ./root there, then move this mount to the root and after // setCapsAndCallSelf, runsc will chroot into /root. 
// // We need a directory to construct a new root and we know that diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 651615d4c..af245b6d8 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -118,7 +118,7 @@ func receiveConsolePTY(srv *unet.ServerSocket) (*os.File, error) { // Test that an pty FD is sent over the console socket if one is provided. func TestConsoleSocket(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) spec := testutil.NewSpecWithArgs("true") rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) @@ -163,7 +163,7 @@ func TestConsoleSocket(t *testing.T) { // Test that job control signals work on a console created with "exec -ti". func TestJobControlSignalExec(t *testing.T) { spec := testutil.NewSpecWithArgs("/bin/sleep", "10000") - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) if err != nil { @@ -286,7 +286,7 @@ func TestJobControlSignalExec(t *testing.T) { // Test that job control signals work on a console created with "run -ti". func TestJobControlSignalRootContainer(t *testing.T) { - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) // Don't let bash execute from profile or rc files, otherwise our PID // counts get messed up. spec := testutil.NewSpecWithArgs("/bin/bash", "--noprofile", "--norc") diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 442e80ac0..5db6d64aa 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -251,12 +251,12 @@ var noOverlay = []configOption{kvm, nonExclusiveFS} var all = append(noOverlay, overlay) // configs generates different configurations to run tests. -func configs(opts ...configOption) []*boot.Config { +func configs(t *testing.T, opts ...configOption) []*boot.Config { // Always load the default config. - cs := []*boot.Config{testutil.TestConfig()} + cs := []*boot.Config{testutil.TestConfig(t)} for _, o := range opts { - c := testutil.TestConfig() + c := testutil.TestConfig(t) switch o { case overlay: c.Overlay = true @@ -285,7 +285,7 @@ func TestLifecycle(t *testing.T) { childReaper.Start() defer childReaper.Stop() - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) // The container will just sleep for a long time. We will kill it before // it finishes sleeping. @@ -457,7 +457,7 @@ func TestExePath(t *testing.T) { t.Fatal(err) } - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) for _, test := range []struct { path string @@ -521,9 +521,19 @@ func TestExePath(t *testing.T) { // Test the we can retrieve the application exit status from the container. func TestAppExitStatus(t *testing.T) { + doAppExitStatus(t, false) +} + +// This is TestAppExitStatus for VFSv2. +func TestAppExitStatusVFS2(t *testing.T) { + doAppExitStatus(t, true) +} + +func doAppExitStatus(t *testing.T, vfs2 bool) { // First container will succeed. 
succSpec := testutil.NewSpecWithArgs("true") - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) + conf.VFS2 = vfs2 rootDir, bundleDir, err := testutil.SetupContainer(succSpec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -573,7 +583,7 @@ func TestAppExitStatus(t *testing.T) { // TestExec verifies that a container can exec a new program. func TestExec(t *testing.T) { - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) const uid = 343 @@ -667,7 +677,7 @@ func TestExec(t *testing.T) { // TestKillPid verifies that we can signal individual exec'd processes. func TestKillPid(t *testing.T) { - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) app, err := testutil.FindFile("runsc/container/test_app/test_app") @@ -743,7 +753,7 @@ func TestKillPid(t *testing.T) { // be the next consecutive number after the last number from the checkpointed container. func TestCheckpointRestore(t *testing.T) { // Skip overlay because test requires writing to host file. - for _, conf := range configs(noOverlay...) { + for _, conf := range configs(t, noOverlay...) { t.Logf("Running test with conf: %+v", conf) dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test") @@ -904,7 +914,7 @@ func TestCheckpointRestore(t *testing.T) { // with filesystem Unix Domain Socket use. func TestUnixDomainSockets(t *testing.T) { // Skip overlay because test requires writing to host file. - for _, conf := range configs(noOverlay...) { + for _, conf := range configs(t, noOverlay...) { t.Logf("Running test with conf: %+v", conf) // UDS path is limited to 108 chars for compatibility with older systems. @@ -1042,7 +1052,7 @@ func TestUnixDomainSockets(t *testing.T) { // recreated. Then it resumes the container, verify that the file gets created // again. func TestPauseResume(t *testing.T) { - for _, conf := range configs(noOverlay...) { + for _, conf := range configs(t, noOverlay...) { t.Run(fmt.Sprintf("conf: %+v", conf), func(t *testing.T) { t.Logf("Running test with conf: %+v", conf) @@ -1123,7 +1133,7 @@ func TestPauseResume(t *testing.T) { // occurs given the correct state. func TestPauseResumeStatus(t *testing.T) { spec := testutil.NewSpecWithArgs("sleep", "20") - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1189,7 +1199,7 @@ func TestCapabilities(t *testing.T) { uid := auth.KUID(os.Getuid() + 1) gid := auth.KGID(os.Getgid() + 1) - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) spec := testutil.NewSpecWithArgs("sleep", "100") @@ -1278,7 +1288,7 @@ func TestCapabilities(t *testing.T) { // TestRunNonRoot checks that sandbox can be configured when running as // non-privileged user. func TestRunNonRoot(t *testing.T) { - for _, conf := range configs(noOverlay...) { + for _, conf := range configs(t, noOverlay...) { t.Logf("Running test with conf: %+v", conf) spec := testutil.NewSpecWithArgs("/bin/true") @@ -1322,7 +1332,7 @@ func TestRunNonRoot(t *testing.T) { // TestMountNewDir checks that runsc will create destination directory if it // doesn't exit. 
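The container tests now thread *testing.T into TestConfig and configs so each generated boot.Config can be tied to the test that produced it, for example through per-test debug log names. A compressed sketch of that configuration-matrix pattern with stand-in types; config, testConfig and testConfigs here are illustrative, not the runsc types.

package matrix

import (
	"fmt"
	"testing"
)

// config stands in for boot.Config; only the fields exercised here are listed.
type config struct {
	DebugLog string
	Overlay  bool
	VFS2     bool
}

// testConfig mirrors testutil.TestConfig(t): the test name becomes part of the
// debug log location so concurrent tests do not clobber each other's logs.
func testConfig(t *testing.T) *config {
	return &config{DebugLog: "runsc.log." + t.Name()}
}

// testConfigs mirrors configs(t, ...): always include the default
// configuration, plus one variant per requested option.
func testConfigs(t *testing.T, overlay, vfs2 bool) []*config {
	cs := []*config{testConfig(t)}
	if overlay {
		c := testConfig(t)
		c.Overlay = true
		cs = append(cs, c)
	}
	if vfs2 {
		c := testConfig(t)
		c.VFS2 = true
		cs = append(cs, c)
	}
	return cs
}

func TestMatrix(t *testing.T) {
	for _, conf := range testConfigs(t, true, true) {
		t.Run(fmt.Sprintf("%+v", conf), func(t *testing.T) {
			t.Logf("Running test with conf: %+v", conf)
		})
	}
}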
func TestMountNewDir(t *testing.T) { - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) root, err := ioutil.TempDir(testutil.TmpDir(), "root") @@ -1351,7 +1361,7 @@ func TestMountNewDir(t *testing.T) { } func TestReadonlyRoot(t *testing.T) { - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) spec := testutil.NewSpecWithArgs("/bin/touch", "/foo") @@ -1389,7 +1399,7 @@ func TestReadonlyRoot(t *testing.T) { } func TestUIDMap(t *testing.T) { - for _, conf := range configs(noOverlay...) { + for _, conf := range configs(t, noOverlay...) { t.Logf("Running test with conf: %+v", conf) testDir, err := ioutil.TempDir(testutil.TmpDir(), "test-mount") if err != nil { @@ -1470,7 +1480,7 @@ func TestUIDMap(t *testing.T) { } func TestReadonlyMount(t *testing.T) { - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount") @@ -1527,7 +1537,7 @@ func TestAbbreviatedIDs(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir cids := []string{ @@ -1585,7 +1595,7 @@ func TestAbbreviatedIDs(t *testing.T) { func TestGoferExits(t *testing.T) { spec := testutil.NewSpecWithArgs("/bin/sleep", "10000") - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1654,7 +1664,7 @@ func TestRootNotMount(t *testing.T) { spec.Root.Readonly = true spec.Mounts = nil - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) if err := run(spec, conf); err != nil { t.Fatalf("error running sandbox: %v", err) } @@ -1668,7 +1678,7 @@ func TestUserLog(t *testing.T) { // sched_rr_get_interval = 148 - not implemented in gvisor. spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall=148") - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1708,7 +1718,7 @@ func TestUserLog(t *testing.T) { } func TestWaitOnExitedSandbox(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) 
{ t.Logf("Running test with conf: %+v", conf) // Run a shell that sleeps for 1 second and then exits with a @@ -1763,7 +1773,7 @@ func TestWaitOnExitedSandbox(t *testing.T) { func TestDestroyNotStarted(t *testing.T) { spec := testutil.NewSpecWithArgs("/bin/sleep", "100") - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1790,7 +1800,7 @@ func TestDestroyNotStarted(t *testing.T) { func TestDestroyStarting(t *testing.T) { for i := 0; i < 10; i++ { spec := testutil.NewSpecWithArgs("/bin/sleep", "100") - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1835,7 +1845,7 @@ func TestDestroyStarting(t *testing.T) { } func TestCreateWorkingDir(t *testing.T) { - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create") @@ -1908,7 +1918,7 @@ func TestMountPropagation(t *testing.T) { }, } - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1959,7 +1969,7 @@ func TestMountPropagation(t *testing.T) { } func TestMountSymlink(t *testing.T) { - for _, conf := range configs(overlay) { + for _, conf := range configs(t, overlay) { t.Logf("Running test with conf: %+v", conf) dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink") @@ -2039,7 +2049,7 @@ func TestNetRaw(t *testing.T) { } for _, enableRaw := range []bool{true, false} { - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.EnableRaw = enableRaw test := "--enabled" @@ -2056,7 +2066,7 @@ func TestNetRaw(t *testing.T) { // TestOverlayfsStaleRead most basic test that '--overlayfs-stale-read' works. func TestOverlayfsStaleRead(t *testing.T) { - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.OverlayfsStaleRead = true in, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.in") @@ -2120,7 +2130,7 @@ func TestTTYField(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) // We will run /bin/sleep, possibly with an open TTY. cmd := []string{"/bin/sleep", "10000"} diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 2da93ec5b..dc2fb42ce 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -135,7 +135,7 @@ func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { // TestMultiContainerSanity checks that it is possible to run 2 dead-simple // containers in the same sandbox. func TestMultiContainerSanity(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -173,7 +173,7 @@ func TestMultiContainerSanity(t *testing.T) { // TestMultiPIDNS checks that it is possible to run 2 dead-simple // containers in the same sandbox with different pidns. func TestMultiPIDNS(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) 
{ t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -218,7 +218,7 @@ func TestMultiPIDNS(t *testing.T) { // TestMultiPIDNSPath checks the pidns path. func TestMultiPIDNSPath(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -289,7 +289,7 @@ func TestMultiContainerWait(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir // The first container should run the entire duration of the test. @@ -367,7 +367,7 @@ func TestExecWait(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir // The first container should run the entire duration of the test. @@ -463,7 +463,7 @@ func TestMultiContainerMount(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir containers, cleanup, err := startContainers(conf, sps, ids) @@ -484,7 +484,7 @@ func TestMultiContainerMount(t *testing.T) { // TestMultiContainerSignal checks that it is possible to signal individual // containers without killing the entire sandbox. func TestMultiContainerSignal(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -585,7 +585,7 @@ func TestMultiContainerDestroy(t *testing.T) { t.Fatal("error finding test_app:", err) } - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -653,7 +653,7 @@ func TestMultiContainerProcesses(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir // Note: use curly braces to keep 'sh' process around. Otherwise, shell @@ -712,7 +712,7 @@ func TestMultiContainerKillAll(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir for _, tc := range []struct { @@ -804,7 +804,7 @@ func TestMultiContainerDestroyNotStarted(t *testing.T) { []string{"/bin/sleep", "100"}, []string{"/bin/sleep", "100"}) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -858,7 +858,7 @@ func TestMultiContainerDestroyStarting(t *testing.T) { } specs, ids := createSpecs(cmds...) 
- conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -943,7 +943,7 @@ func TestMultiContainerDifferentFilesystems(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir // Make sure overlay is enabled, and none of the root filesystems are @@ -1006,7 +1006,7 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { childrenSpecs := allSpecs[1:] childrenIDs := allIDs[1:] - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) rootDir, bundleDir, err := testutil.SetupContainer(rootSpec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1080,7 +1080,7 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { // Test that pod shared mounts are properly mounted in 2 containers and that // changes from one container is reflected in the other. func TestMultiContainerSharedMount(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -1195,7 +1195,7 @@ func TestMultiContainerSharedMount(t *testing.T) { // Test that pod mounts are mounted as readonly when requested. func TestMultiContainerSharedMountReadonly(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -1262,7 +1262,7 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { // Test that shared pod mounts continue to work after container is restarted. func TestMultiContainerSharedMountRestart(t *testing.T) { - for _, conf := range configs(all...) { + for _, conf := range configs(t, all...) { t.Logf("Running test with conf: %+v", conf) rootDir, err := testutil.SetupRootDir() @@ -1381,7 +1381,7 @@ func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir // Setup the containers. @@ -1463,7 +1463,7 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir // Create the specs. @@ -1500,7 +1500,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir sleep := []string{"sleep", "100"} @@ -1587,7 +1587,7 @@ func TestMultiContainerLoadSandbox(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir // Create containers for the sandbox. @@ -1687,7 +1687,7 @@ func TestMultiContainerRunNonRoot(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir pod, cleanup, err := startContainers(conf, podSpecs, ids) diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go index dc4194134..f80852414 100644 --- a/runsc/container/shared_volume_test.go +++ b/runsc/container/shared_volume_test.go @@ -31,7 +31,7 @@ import ( // TestSharedVolume checks that modifications to a volume mount are propagated // into and out of the sandbox. 
func TestSharedVolume(t *testing.T) { - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.FileAccess = boot.FileAccessShared t.Logf("Running test with conf: %+v", conf) @@ -190,7 +190,7 @@ func checkFile(c *Container, filename string, want []byte) error { // TestSharedVolumeFile tests that changes to file content outside the sandbox // is reflected inside. func TestSharedVolumeFile(t *testing.T) { - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.FileAccess = boot.FileAccessShared t.Logf("Running test with conf: %+v", conf) diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go index 01c47c79f..5f1c4b7d6 100644 --- a/runsc/container/test_app/test_app.go +++ b/runsc/container/test_app/test_app.go @@ -96,7 +96,7 @@ func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) listener, err := net.Listen("unix", c.socketPath) if err != nil { - log.Fatal("error listening on socket %q:", c.socketPath, err) + log.Fatalf("error listening on socket %q: %v", c.socketPath, err) } go server(listener, outputFile) diff --git a/runsc/main.go b/runsc/main.go index 62e184ec9..2baba90f8 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -84,6 +84,7 @@ var ( rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.") cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") + vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") // Test flags, not to be used outside tests, ever. testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! 
This skips many security measures that isolate the host from the sandbox.") @@ -230,6 +231,7 @@ func main() { ReferenceLeakMode: refsLeakMode, OverlayfsStaleRead: *overlayfsStaleRead, CPUNumFromQuota: *cpuNumFromQuota, + VFS2: *vfs2Enabled, TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, TestOnlyTestNameEnv: *testOnlyTestNameEnv, @@ -294,9 +296,7 @@ func main() { if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) } - } - - if *alsoLogToStderr { + } else if *alsoLogToStderr { e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)} } @@ -313,6 +313,7 @@ func main() { log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay) log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets) log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls) + log.Infof("\t\tVFS2 enabled: %v", conf.VFS2) log.Infof("***************************") if *testOnlyAllowRunAsCurrentUserWithoutChroot { @@ -342,11 +343,11 @@ func main() { func newEmitter(format string, logFile io.Writer) log.Emitter { switch format { case "text": - return &log.GoogleEmitter{log.Writer{Next: logFile}} + return log.GoogleEmitter{&log.Writer{Next: logFile}} case "json": - return &log.JSONEmitter{log.Writer{Next: logFile}} + return log.JSONEmitter{&log.Writer{Next: logFile}} case "json-k8s": - return &log.K8sJSONEmitter{log.Writer{Next: logFile}} + return log.K8sJSONEmitter{&log.Writer{Next: logFile}} } cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format) panic("unreachable") diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 3b06da98b..e82bcef6f 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -18,10 +18,12 @@ package sandbox import ( "context" "fmt" + "io" "math" "os" "os/exec" "strconv" + "strings" "syscall" "time" @@ -142,7 +144,19 @@ func New(conf *boot.Config, args *Args) (*Sandbox, error) { // Wait until the sandbox has booted. b := make([]byte, 1) if l, err := clientSyncFile.Read(b); err != nil || l != 1 { - return nil, fmt.Errorf("waiting for sandbox to start: %v", err) + err := fmt.Errorf("waiting for sandbox to start: %v", err) + // If the sandbox failed to start, it may be because the binary + // permissions were incorrect. Check the bits and return a more helpful + // error message. + // + // NOTE: The error message is checked because error types are lost over + // rpc calls. + if strings.Contains(err.Error(), io.EOF.Error()) { + if permsErr := checkBinaryPermissions(conf); permsErr != nil { + return nil, fmt.Errorf("%v: %v", err, permsErr) + } + } + return nil, err } c.Release() @@ -388,8 +402,6 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF nextFD++ } - cmd.Args = append(cmd.Args, "--panic-signal="+strconv.Itoa(int(syscall.SIGTERM))) - // Add the "boot" command to the args. // // All flags after this must be for the boot command @@ -706,7 +718,19 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF log.Debugf("Starting sandbox: %s %v", binPath, cmd.Args) log.Debugf("SysProcAttr: %+v", cmd.SysProcAttr) if err := specutils.StartInNS(cmd, nss); err != nil { - return fmt.Errorf("Sandbox: %v", err) + err := fmt.Errorf("starting sandbox: %v", err) + // If the sandbox failed to start, it may be because the binary + // permissions were incorrect. 
Check the bits and return a more helpful + // error message. + // + // NOTE: The error message is checked because error types are lost over + // rpc calls. + if strings.Contains(err.Error(), syscall.EACCES.Error()) { + if permsErr := checkBinaryPermissions(conf); permsErr != nil { + return fmt.Errorf("%v: %v", err, permsErr) + } + } + return err } s.child = true s.Pid = cmd.Process.Pid @@ -1169,3 +1193,31 @@ func deviceFileForPlatform(name string) (*os.File, error) { } return f, nil } + +// checkBinaryPermissions verifies that the required binary bits are set on +// the runsc executable. +func checkBinaryPermissions(conf *boot.Config) error { + // All platforms need the other exe bit + neededBits := os.FileMode(0001) + if conf.Platform == platforms.Ptrace { + // Ptrace needs the other read bit + neededBits |= os.FileMode(0004) + } + + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("getting exe path: %v", err) + } + + // Check the permissions of the runsc binary and print an error if it + // doesn't match expectations. + info, err := os.Stat(exePath) + if err != nil { + return fmt.Errorf("stat file: %v", err) + } + + if info.Mode().Perm()&neededBits != neededBits { + return fmt.Errorf(specutils.FaqErrorMsg("runsc-perms", fmt.Sprintf("%s does not have the correct permissions", exePath))) + } + return nil +} diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index d3c2e4e78..837d5e238 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -92,6 +92,12 @@ func ValidateSpec(spec *specs.Spec) error { log.Warningf("AppArmor profile %q is being ignored", spec.Process.ApparmorProfile) } + // PR_SET_NO_NEW_PRIVS is assumed to always be set. + // See kernel.Task.updateCredsForExecLocked. + if !spec.Process.NoNewPrivileges { + log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.") + } + // TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox. if spec.Linux != nil && spec.Linux.Seccomp != nil { log.Warningf("Seccomp spec is being ignored") @@ -528,3 +534,8 @@ func EnvVar(env []string, name string) (string, bool) { } return "", false } + +// FaqErrorMsg returns an error message pointing to the FAQ. +func FaqErrorMsg(anchor, msg string) string { + return fmt.Sprintf("%s; see https://gvisor.dev/faq#%s for more details", msg, anchor) +} diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index 51e487715..5e09f8f16 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -31,11 +31,13 @@ import ( "os" "os/exec" "os/signal" + "path" "path/filepath" "strconv" "strings" "sync/atomic" "syscall" + "testing" "time" "github.com/cenkalti/backoff" @@ -81,17 +83,16 @@ func ConfigureExePath() error { // TestConfig returns the default configuration to use in tests. Note that // 'RootDir' must be set by caller if required. 
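checkBinaryPermissions reduces to comparing the mode bits from os.Stat against the bits the chosen platform needs. A small stand-alone sketch of that comparison; hasBits is an illustrative helper name.

package main

import (
	"fmt"
	"os"
)

// hasBits reports whether every permission bit in want is set on path. This is
// the same mask-and-compare that checkBinaryPermissions applies to the runsc
// executable.
func hasBits(path string, want os.FileMode) (bool, error) {
	info, err := os.Stat(path)
	if err != nil {
		return false, fmt.Errorf("stat file: %v", err)
	}
	return info.Mode().Perm()&want == want, nil
}

func main() {
	// 0005 is "other" read+execute, the combination a ptrace-platform runsc needs.
	ok, err := hasBits("/bin/ls", 0005)
	fmt.Println(ok, err)
}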
-func TestConfig() *boot.Config { +func TestConfig(t *testing.T) *boot.Config { logDir := "" if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { logDir = dir + "/" } return &boot.Config{ Debug: true, - DebugLog: logDir, + DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"), LogFormat: "text", DebugLogFormat: "text", - AlsoLogToStderr: true, LogPackets: true, Network: boot.NetworkNone, Strace: true, diff --git a/scripts/common.sh b/scripts/common.sh index 735a383de..bc6ba71e8 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -89,12 +89,20 @@ function install_runsc() { # be correct, otherwise this may result in a loop that spins until time out. function apt_install() { while true; do - if (sudo apt-get update && sudo apt-get install -y "$@"); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - return $result - fi + sudo apt-get update && + sudo apt-get install -y "$@" && + true + result="${?}" + case $result in + 0) + break + ;; + 100) + # 100 is the error code that apt-get returns. + ;; + *) + exit $result + ;; + esac done } diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 199823419..b6a254882 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -28,6 +28,7 @@ go_library( "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//keepalive:go_default_library", "@org_golang_x_sys//unix:go_default_library", + "@org_uber_go_multierr//:go_default_library", ], ) @@ -36,4 +37,5 @@ go_test( size = "small", srcs = ["layers_test.go"], library = ":testbench", + deps = ["//pkg/tcpip"], ) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 579da59c3..f84fd8ba7 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -21,10 +21,12 @@ import ( "fmt" "math/rand" "net" + "strings" "testing" "time" "github.com/mohae/deepcopy" + "go.uber.org/multierr" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -62,384 +64,607 @@ func pickPort() (int, uint16, error) { return fd, uint16(newSockAddrInet4.Port), nil } -// TCPIPv4 maintains state about a TCP/IPv4 connection. -type TCPIPv4 struct { - outgoing Layers - incoming Layers - LocalSeqNum seqnum.Value - RemoteSeqNum seqnum.Value - SynAck *TCP - sniffer Sniffer - injector Injector - portPickerFD int - t *testing.T +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. + outgoing() Layer + + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. 
+ close() error } -// tcpLayerIndex is the position of the TCP layer in the TCPIPv4 connection. It -// is the third, after Ethernet and IPv4. -const tcpLayerIndex int = 2 +// etherState maintains state about an Ethernet connection. +type etherState struct { + out, in Ether +} -// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. -func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { +var _ layerState = (*etherState)(nil) + +// newEtherState creates a new etherState. +func newEtherState(out, in Ether) (*etherState, error) { lMAC, err := tcpip.ParseMACAddress(*localMAC) if err != nil { - t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) + return nil, err } rMAC, err := tcpip.ParseMACAddress(*remoteMAC) if err != nil { - t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err) + return nil, err } - - portPickerFD, localPort, err := pickPort() - if err != nil { - t.Fatalf("can't pick a port: %s", err) + s := etherState{ + out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, + in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *etherState) outgoing() Layer { + return &s.out +} + +func (s *etherState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*etherState) sent(Layer) error { + return nil +} + +func (*etherState) received(Layer) error { + return nil +} + +func (*etherState) close() error { + return nil +} + +// ipv4State maintains state about an IPv4 connection. +type ipv4State struct { + out, in IPv4 +} + +var _ layerState = (*ipv4State)(nil) + +// newIPv4State creates a new ipv4State. +func newIPv4State(out, in IPv4) (*ipv4State, error) { lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) - - sniffer, err := NewSniffer(t) - if err != nil { - t.Fatalf("can't make new sniffer: %s", err) + s := ipv4State{ + out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err } + return &s, nil +} - injector, err := NewInjector(t) +func (s *ipv4State) outgoing() Layer { + return &s.out +} + +func (s *ipv4State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*ipv4State) sent(Layer) error { + return nil +} + +func (*ipv4State) received(Layer) error { + return nil +} + +func (*ipv4State) close() error { + return nil +} + +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} + +var _ layerState = (*tcpState)(nil) + +// SeqNumValue is a helper routine that allocates a new seqnum.Value value to +// store v and returns a pointer to it. +func SeqNumValue(v seqnum.Value) *seqnum.Value { + return &v +} + +// newTCPState creates a new TCPState. 
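Stateless layers such as etherState and ipv4State only ever hand back their canned outgoing and incoming layers and have nothing to update or release. A compressed stand-alone sketch of such a layerState implementation, with a stub Layer interface standing in for the richer testbench one.

package main

import "fmt"

// Layer is a stub for the testbench Layer interface.
type Layer interface {
	String() string
}

type stubLayer struct{ name string }

func (l *stubLayer) String() string { return l.name }

// layerState mirrors the interface introduced in connections.go.
type layerState interface {
	outgoing() Layer
	incoming(received Layer) Layer
	sent(sent Layer) error
	received(received Layer) error
	close() error
}

// staticState is a stateless layerState: it always sends and expects the same
// layers, and sending or receiving a frame changes nothing.
type staticState struct{ out, in Layer }

func (s *staticState) outgoing() Layer      { return s.out }
func (s *staticState) incoming(Layer) Layer { return s.in }
func (*staticState) sent(Layer) error       { return nil }
func (*staticState) received(Layer) error   { return nil }
func (*staticState) close() error           { return nil }

var _ layerState = (*staticState)(nil)

func main() {
	s := &staticState{out: &stubLayer{"ether out"}, in: &stubLayer{"ether in"}}
	fmt.Println(s.outgoing(), s.incoming(nil))
}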
+func newTCPState(out, in TCP) (*tcpState, error) { + portPickerFD, localPort, err := pickPort() if err != nil { - t.Fatalf("can't make new injector: %s", err) + return nil, err + } + s := tcpState{ + out: TCP{SrcPort: &localPort}, + in: TCP{DstPort: &localPort}, + localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), + portPickerFD: portPickerFD, + finSent: false, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err } + return &s, nil +} - newOutgoingTCP := &TCP{ - SrcPort: &localPort, +func (s *tcpState) outgoing() Layer { + newOutgoing := deepcopy.Copy(s.out).(TCP) + if s.localSeqNum != nil { + newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) } - if err := newOutgoingTCP.merge(outgoingTCP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", outgoingTCP, newOutgoingTCP, err) + if s.remoteSeqNum != nil { + newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) } - newIncomingTCP := &TCP{ - DstPort: &localPort, + return &newOutgoing +} + +func (s *tcpState) incoming(received Layer) Layer { + tcpReceived, ok := received.(*TCP) + if !ok { + return nil } - if err := newIncomingTCP.merge(incomingTCP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", incomingTCP, newIncomingTCP, err) + newIn := deepcopy.Copy(s.in).(TCP) + if s.remoteSeqNum != nil { + newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) } - return TCPIPv4{ - outgoing: Layers{ - &Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, - &IPv4{SrcAddr: &lIP, DstAddr: &rIP}, - newOutgoingTCP}, - incoming: Layers{ - &Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, - &IPv4{SrcAddr: &rIP, DstAddr: &lIP}, - newIncomingTCP}, - sniffer: sniffer, - injector: injector, - portPickerFD: portPickerFD, - t: t, - LocalSeqNum: seqnum.Value(rand.Uint32()), + if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 { + // The caller didn't specify an AckNum so we'll expect the calculated one, + // but only if the ACK flag is set because the AckNum is not valid in a + // header if ACK is not set. + newIn.AckNum = Uint32(uint32(*s.localSeqNum)) } + return &newIn } -// Close the injector and sniffer associated with this connection. -func (conn *TCPIPv4) Close() { - conn.sniffer.Close() - conn.injector.Close() - if err := unix.Close(conn.portPickerFD); err != nil { - conn.t.Fatalf("can't close portPickerFD: %s", err) +func (s *tcpState) sent(sent Layer) error { + tcp, ok := sent.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", sent) } - conn.portPickerFD = -1 + if !s.finSent { + // update localSeqNum by the payload only when FIN is not yet sent by us + for current := tcp.next(); current != nil; current = current.next() { + s.localSeqNum.UpdateForward(seqnum.Size(current.length())) + } + } + if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.localSeqNum.UpdateForward(1) + } + if *tcp.Flags&(header.TCPFlagFin) != 0 { + s.finSent = true + } + return nil } -// CreateFrame builds a frame for the connection with tcp overriding defaults -// and additionalLayers added after the TCP header. 
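tcpState.sent advances the local sequence number by the payload length, plus one when SYN or FIN is set; tcpState.received does the mirror-image update for the remote side. The same accounting with plain uint32 arithmetic instead of seqnum.Value, as a sanity check; the flag constants here are illustrative, not the header package values.

package main

import "fmt"

const (
	flagFin = 1 << 0
	flagSyn = 1 << 1
)

// advanceSeqNum mirrors the accounting in tcpState.sent: the sequence number
// moves forward by the payload length, and SYN and FIN each consume one extra
// sequence number.
func advanceSeqNum(seq uint32, payloadLen int, flags uint8) uint32 {
	seq += uint32(payloadLen)
	if flags&(flagSyn|flagFin) != 0 {
		seq++
	}
	return seq
}

func main() {
	seq := uint32(1000)
	seq = advanceSeqNum(seq, 0, flagSyn) // SYN: 1001
	seq = advanceSeqNum(seq, 100, 0)     // 100 payload bytes: 1101
	seq = advanceSeqNum(seq, 0, flagFin) // FIN: 1102
	fmt.Println(seq)
}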
-func (conn *TCPIPv4) CreateFrame(tcp TCP, additionalLayers ...Layer) Layers { - if tcp.SeqNum == nil { - tcp.SeqNum = Uint32(uint32(conn.LocalSeqNum)) +func (s *tcpState) received(l Layer) error { + tcp, ok := l.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", l) } - if tcp.AckNum == nil { - tcp.AckNum = Uint32(uint32(conn.RemoteSeqNum)) + s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) + if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.remoteSeqNum.UpdateForward(1) } - layersToSend := deepcopy.Copy(conn.outgoing).(Layers) - if err := layersToSend[tcpLayerIndex].(*TCP).merge(tcp); err != nil { - conn.t.Fatalf("can't merge %+v into %+v: %s", tcp, layersToSend[tcpLayerIndex], err) + for current := tcp.next(); current != nil; current = current.next() { + s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) } - layersToSend = append(layersToSend, additionalLayers...) - return layersToSend + return nil } -// SendFrame sends a frame with reasonable defaults. -func (conn *TCPIPv4) SendFrame(frame Layers) { - outBytes, err := frame.toBytes() - if err != nil { - conn.t.Fatalf("can't build outgoing TCP packet: %s", err) +// close frees the port associated with this connection. +func (s *tcpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err } - conn.injector.Send(outBytes) + s.portPickerFD = -1 + return nil +} - // Compute the next TCP sequence number. - for i := tcpLayerIndex + 1; i < len(frame); i++ { - conn.LocalSeqNum.UpdateForward(seqnum.Size(frame[i].length())) +// udpState maintains state about a UDP connection. +type udpState struct { + out, in UDP + portPickerFD int +} + +var _ layerState = (*udpState)(nil) + +// newUDPState creates a new udpState. +func newUDPState(out, in UDP) (*udpState, error) { + portPickerFD, localPort, err := pickPort() + if err != nil { + return nil, err } - tcp := frame[tcpLayerIndex].(*TCP) - if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - conn.LocalSeqNum.UpdateForward(1) + s := udpState{ + out: UDP{SrcPort: &localPort}, + in: UDP{DstPort: &localPort}, + portPickerFD: portPickerFD, + } + if err := s.out.merge(&out); err != nil { + return nil, err } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil } -// Send a packet with reasonable defaults and override some fields by tcp. -func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { - conn.SendFrame(conn.CreateFrame(tcp, additionalLayers...)) +func (s *udpState) outgoing() Layer { + return &s.out +} + +func (s *udpState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*udpState) sent(l Layer) error { + return nil +} + +func (*udpState) received(l Layer) error { + return nil } -// Recv gets a packet from the sniffer within the timeout provided. -// If no packet arrives before the timeout, it returns nil. -func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP { - layers := conn.RecvFrame(timeout) - if tcpLayerIndex < len(layers) { - return layers[tcpLayerIndex].(*TCP) +// close frees the port associated with this connection. +func (s *udpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err } + s.portPickerFD = -1 return nil } -// RecvFrame gets a frame (of type Layers) within the timeout provided. -// If no frame arrives before the timeout, it returns nil. 
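newTCPState and newUDPState both reserve a local port through pickPort before building their state, and close() releases it. A standard-library analogue of that reservation (the real helper works with raw sockets via the unix package): bind to port 0 and read back the kernel's choice; reservePort is an illustrative name.

package main

import (
	"fmt"
	"net"
)

// reservePort grabs a free local port by binding to port 0 and reading back the
// kernel-assigned value. The returned conn must stay open for as long as the
// reservation should hold, and closing it releases the port.
func reservePort() (net.PacketConn, uint16, error) {
	conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
	if err != nil {
		return nil, 0, err
	}
	port := uint16(conn.LocalAddr().(*net.UDPAddr).Port)
	return conn, port, nil
}

func main() {
	conn, port, err := reservePort()
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	fmt.Println("reserved port", port)
}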
-func (conn *TCPIPv4) RecvFrame(timeout time.Duration) Layers { - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - break - } - b := conn.sniffer.Recv(timeout) - if b == nil { - break - } - layers, err := ParseEther(b) - if err != nil { - conn.t.Logf("can't parse frame: %s", err) - continue // Ignore packets that can't be parsed. - } - if !conn.incoming.match(layers) { - continue // Ignore packets that don't match the expected incoming. +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer + t *testing.T +} + +// match tries to match each Layer in received against the incoming filter. If +// received is longer than layerStates then that may still count as a match. The +// reverse is never a match. override overrides the default matchers for each +// Layer. +func (conn *Connection) match(override, received Layers) bool { + if len(received) < len(conn.layerStates) { + return false + } + for i, s := range conn.layerStates { + toMatch := s.incoming(received[i]) + if toMatch == nil { + return false } - tcpHeader := (layers[tcpLayerIndex]).(*TCP) - conn.RemoteSeqNum = seqnum.Value(*tcpHeader.SeqNum) - if *tcpHeader.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { - conn.RemoteSeqNum.UpdateForward(1) + if i < len(override) { + toMatch.merge(override[i]) } - for i := tcpLayerIndex + 1; i < len(layers); i++ { - conn.RemoteSeqNum.UpdateForward(seqnum.Size(layers[i].length())) + if !toMatch.match(received[i]) { + return false } - return layers } - return nil + return true } -// Expect a packet that matches the provided tcp within the timeout specified. -// If it doesn't arrive in time, it returns nil. -func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) *TCP { - // We cannot implement this directly using ExpectFrame as we cannot specify - // the Payload part. - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil +// Close frees associated resources held by the Connection. +func (conn *Connection) Close() { + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) + for _, s := range conn.layerStates { + if err := s.close(); err != nil { + errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) } - gotTCP := conn.Recv(timeout) - if tcp.match(gotTCP) { - return gotTCP + } + if errs != nil { + conn.t.Fatalf("unable to close %+v: %s", conn, errs) + } +} + +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + var layersToSend Layers + for _, s := range conn.layerStates { + layersToSend = append(layersToSend, s.outgoing()) + } + if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil { + conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err) + } + layersToSend = append(layersToSend, additionalLayers...) + return layersToSend +} + +// SendFrame sends a frame on the wire and updates the state of all layers. 
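Connection.Close uses go.uber.org/multierr so that every failing close is reported rather than only the first. A minimal stand-alone sketch of that aggregation; closeAll and the error messages are illustrative.

package main

import (
	"errors"
	"fmt"

	"go.uber.org/multierr"
)

type closer func() error

// closeAll mirrors Connection.Close: attempt every close and collect all
// failures into a single error instead of stopping at the first one.
func closeAll(closers ...closer) error {
	var errs error
	for i, c := range closers {
		if err := c(); err != nil {
			errs = multierr.Append(errs, fmt.Errorf("closer %d: %w", i, err))
		}
	}
	return errs
}

func main() {
	err := closeAll(
		func() error { return nil },
		func() error { return errors.New("sniffer already closed") },
		func() error { return errors.New("injector already closed") },
	)
	fmt.Println(err) // Both failures appear in the combined error.
}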
+func (conn *Connection) SendFrame(frame Layers) { + outBytes, err := frame.toBytes() + if err != nil { + conn.t.Fatalf("can't build outgoing TCP packet: %s", err) + } + conn.injector.Send(outBytes) + + // frame might have nil values where the caller wanted to use default values. + // sentFrame will have no nil values in it because it comes from parsing the + // bytes that were actually sent. + sentFrame := parse(parseEther, outBytes) + // Update the state of each layer based on what was sent. + for i, s := range conn.layerStates { + if err := s.sent(sentFrame[i]); err != nil { + conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) } } } -// ExpectFrame expects a frame that matches the specified layers within the +// Send a packet with reasonable defaults. Potentially override the final layer +// in the connection with the provided layer and add additionLayers. +func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) { + conn.SendFrame(conn.CreateFrame(layer, additionalLayers...)) +} + +// recvFrame gets the next successfully parsed frame (of type Layers) within the +// timeout provided. If no parsable frame arrives before the timeout, it returns +// nil. +func (conn *Connection) recvFrame(timeout time.Duration) Layers { + if timeout <= 0 { + return nil + } + b := conn.sniffer.Recv(timeout) + if b == nil { + return nil + } + return parse(parseEther, b) +} + +// Expect a frame with the final layerStates layer matching the provided Layer +// within the timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { + // Make a frame that will ignore all but the final layer. + layers := make([]Layer, len(conn.layerStates)) + layers[len(layers)-1] = layer + + gotFrame, err := conn.ExpectFrame(layers, timeout) + if err != nil { + return nil, err + } + if len(conn.layerStates)-1 < len(gotFrame) { + return gotFrame[len(conn.layerStates)-1], nil + } + conn.t.Fatal("the received frame should be at least as long as the expected layers") + return nil, fmt.Errorf("the received frame should be at least as long as the expected layers") +} + +// ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If it doesn't arrive in time, it returns nil. -func (conn *TCPIPv4) ExpectFrame(layers Layers, timeout time.Duration) Layers { +func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { deadline := time.Now().Add(timeout) + var allLayers []string for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil + var gotLayers Layers + if timeout = time.Until(deadline); timeout > 0 { + gotLayers = conn.recvFrame(timeout) } - gotLayers := conn.RecvFrame(timeout) - if layers.match(gotLayers) { - return gotLayers + if gotLayers == nil { + return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n")) } + if conn.match(layers, gotLayers) { + for i, s := range conn.layerStates { + if err := s.received(gotLayers[i]); err != nil { + conn.t.Fatal(err) + } + } + return gotLayers, nil + } + allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers)) } } -// ExpectData is a convenient method that expects a TCP packet along with -// the payload to arrive within the timeout specified. If it doesn't arrive -// in time, it causes a fatal test failure. 
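ExpectFrame keeps receiving until a frame matches or the deadline passes, and its error reports what was skipped along the way. A generic sketch of that deadline loop with strings standing in for parsed frames; expect and its parameters are illustrative names.

package main

import (
	"fmt"
	"time"
)

// expect mirrors the ExpectFrame loop: receive until match succeeds or the
// deadline passes, counting the non-matching items for the error message.
func expect(recv func(timeout time.Duration) (string, bool), match func(string) bool, timeout time.Duration) (string, error) {
	deadline := time.Now().Add(timeout)
	skipped := 0
	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return "", fmt.Errorf("deadline passed after skipping %d packets", skipped)
		}
		got, ok := recv(remaining)
		if !ok {
			return "", fmt.Errorf("receive timed out after skipping %d packets", skipped)
		}
		if match(got) {
			return got, nil
		}
		skipped++
	}
}

func main() {
	packets := []string{"arp", "icmp", "tcp syn-ack"}
	recv := func(time.Duration) (string, bool) {
		if len(packets) == 0 {
			return "", false
		}
		p := packets[0]
		packets = packets[1:]
		return p, true
	}
	got, err := expect(recv, func(s string) bool { return s == "tcp syn-ack" }, time.Second)
	fmt.Println(got, err)
}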
-// ExpectData is a convenient method that expects a TCP packet along with
-// the payload to arrive within the timeout specified. If it doesn't arrive
-// in time, it causes a fatal test failure.
-func (conn *TCPIPv4) ExpectData(tcp TCP, data []byte, timeout time.Duration) {
-	expected := []Layer{&Ether{}, &IPv4{}, &tcp}
-	if len(data) > 0 {
-		expected = append(expected, &Payload{Bytes: data})
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {
+	conn.sniffer.Drain()
+}
+
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+	etherState, err := newEtherState(Ether{}, Ether{})
+	if err != nil {
+		t.Fatalf("can't make etherState: %s", err)
+	}
+	ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+	if err != nil {
+		t.Fatalf("can't make ipv4State: %s", err)
 	}
-	if conn.ExpectFrame(expected, timeout) == nil {
-		conn.t.Fatalf("expected to get a TCP frame %s with payload %x", &tcp, data)
+	tcpState, err := newTCPState(outgoingTCP, incomingTCP)
+	if err != nil {
+		t.Fatalf("can't make tcpState: %s", err)
+	}
+	injector, err := NewInjector(t)
+	if err != nil {
+		t.Fatalf("can't make injector: %s", err)
+	}
+	sniffer, err := NewSniffer(t)
+	if err != nil {
+		t.Fatalf("can't make sniffer: %s", err)
+	}
+
+	return TCPIPv4{
+		layerStates: []layerState{etherState, ipv4State, tcpState},
+		injector:    injector,
+		sniffer:     sniffer,
+		t:           t,
 	}
 }
 
-// Handshake performs a TCP 3-way handshake.
+// Handshake performs a TCP 3-way handshake. The input Connection should have a
+// final TCP Layer.
 func (conn *TCPIPv4) Handshake() {
 	// Send the SYN.
 	conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
 
 	// Wait for the SYN-ACK.
-	conn.SynAck = conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
-	if conn.SynAck == nil {
-		conn.t.Fatalf("didn't get synack during handshake")
+	synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+	if synAck == nil {
+		conn.t.Fatalf("didn't get synack during handshake: %s", err)
 	}
+	conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
 
 	// Send an ACK.
 	conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
 }
 
-// UDPIPv4 maintains state about a UDP/IPv4 connection.
-type UDPIPv4 struct {
-	outgoing     Layers
-	incoming     Layers
-	sniffer      Sniffer
-	injector     Injector
-	portPickerFD int
-	t            *testing.T
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doesn't arrive in time, it returns nil.
+func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+	expected := make([]Layer, len(conn.layerStates))
+	expected[len(expected)-1] = tcp
+	if payload != nil {
+		expected = append(expected, payload)
+	}
+	return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
+// Send a packet with reasonable defaults. Potentially override the TCP layer in
+// the connection with the provided layer and add additionalLayers.
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+	(*Connection)(conn).Send(&tcp, additionalLayers...)
+}
+
+// Close frees associated resources held by the TCPIPv4 connection.
+func (conn *TCPIPv4) Close() {
+	(*Connection)(conn).Close()
 }
 
-// udpLayerIndex is the position of the UDP layer in the UDPIPv4 connection. It
-// is the third, after Ethernet and IPv4.
-const udpLayerIndex int = 2
+// Expect a frame with the TCP layer matching the provided TCP within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
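
For instance, a complete test built on this API generally follows the sketch below, which mirrors the fin_wait2_timeout changes further down in this diff; the package and test names are placeholders, and every call shown is one used elsewhere in this change.

    package example_test

    import (
        "testing"
        "time"

        "golang.org/x/sys/unix"
        "gvisor.dev/gvisor/pkg/tcpip/header"
        tb "gvisor.dev/gvisor/test/packetimpact/testbench"
    )

    func TestSketch(t *testing.T) {
        dut := tb.NewDUT(t)
        defer dut.TearDown()
        listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
        defer dut.Close(listenFd)

        // The outgoing TCP targets the DUT's port; the incoming filter
        // matches replies sourced from it.
        conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
        defer conn.Close()

        conn.Handshake()
        acceptFd, _ := dut.Accept(listenFd)

        // Closing the accepted socket on the DUT should produce a FIN-ACK.
        dut.Close(acceptFd)
        if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
            t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
        }
        conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
    }
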
+func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { + layer, err := (*Connection)(conn).Expect(&tcp, timeout) + if layer == nil { + return nil, err + } + gotTCP, ok := layer.(*TCP) + if !ok { + conn.t.Fatalf("expected %s to be TCP", layer) + } + return gotTCP, err +} + +func (conn *TCPIPv4) state() *tcpState { + state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState) + if !ok { + conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates) + } + return state +} + +// RemoteSeqNum returns the next expected sequence number from the DUT. +func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { + return conn.state().remoteSeqNum +} + +// LocalSeqNum returns the next sequence number to send from the testbench. +func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { + return conn.state().localSeqNum +} + +// SynAck returns the SynAck that was part of the handshake. +func (conn *TCPIPv4) SynAck() *TCP { + return conn.state().synAck +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} + +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection // NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { - lMAC, err := tcpip.ParseMACAddress(*localMAC) + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { - t.Fatalf("can't parse localMAC %q: %s", *localMAC, err) + t.Fatalf("can't make etherState: %s", err) } - - rMAC, err := tcpip.ParseMACAddress(*remoteMAC) + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) if err != nil { - t.Fatalf("can't parse remoteMAC %q: %s", *remoteMAC, err) + t.Fatalf("can't make ipv4State: %s", err) } - - portPickerFD, localPort, err := pickPort() + tcpState, err := newUDPState(outgoingUDP, incomingUDP) if err != nil { - t.Fatalf("can't pick a port: %s", err) + t.Fatalf("can't make udpState: %s", err) } - lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) - rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) - - sniffer, err := NewSniffer(t) + injector, err := NewInjector(t) if err != nil { - t.Fatalf("can't make new sniffer: %s", err) + t.Fatalf("can't make injector: %s", err) } - - injector, err := NewInjector(t) + sniffer, err := NewSniffer(t) if err != nil { - t.Fatalf("can't make new injector: %s", err) + t.Fatalf("can't make sniffer: %s", err) } - newOutgoingUDP := &UDP{ - SrcPort: &localPort, - } - if err := newOutgoingUDP.merge(outgoingUDP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", outgoingUDP, newOutgoingUDP, err) - } - newIncomingUDP := &UDP{ - DstPort: &localPort, - } - if err := newIncomingUDP.merge(incomingUDP); err != nil { - t.Fatalf("can't merge %+v into %+v: %s", incomingUDP, newIncomingUDP, err) - } return UDPIPv4{ - outgoing: Layers{ - &Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, - &IPv4{SrcAddr: &lIP, DstAddr: &rIP}, - newOutgoingUDP}, - incoming: Layers{ - &Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, - &IPv4{SrcAddr: &rIP, DstAddr: &lIP}, - newIncomingUDP}, - sniffer: sniffer, - injector: injector, - portPickerFD: portPickerFD, - t: t, + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, } } -// Close the injector and sniffer associated with this connection. 
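
The accessors above replace the fields that tests used to read directly off the connection. As an illustrative sketch (a hypothetical helper, not part of the patch), this is the arithmetic the new tcp_outside_the_window and tcp_close_wait_ack tests perform to build a segment just past the DUT's receive window:

    package example_test

    import (
        "gvisor.dev/gvisor/pkg/tcpip/header"
        "gvisor.dev/gvisor/pkg/tcpip/seqnum"
        tb "gvisor.dev/gvisor/test/packetimpact/testbench"
    )

    // outOfWindowTCP builds a segment whose sequence number sits just past the
    // right edge of the DUT's receive window: the SYN-ACK recorded during the
    // handshake advertises the window, and LocalSeqNum is the next sequence
    // number the testbench would send.
    func outOfWindowTCP(conn *tb.TCPIPv4) tb.TCP {
        windowSize := seqnum.Size(*conn.SynAck().WindowSize)
        otwSeq := uint32(conn.LocalSeqNum().Add(windowSize))
        return tb.TCP{SeqNum: tb.Uint32(otwSeq), Flags: tb.Uint8(header.TCPFlagAck)}
    }
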
-func (conn *UDPIPv4) Close() { - conn.sniffer.Close() - conn.injector.Close() - if err := unix.Close(conn.portPickerFD); err != nil { - conn.t.Fatalf("can't close portPickerFD: %s", err) - } - conn.portPickerFD = -1 -} - -// CreateFrame builds a frame for the connection with the provided udp -// overriding defaults and the additionalLayers added after the UDP header. -func (conn *UDPIPv4) CreateFrame(udp UDP, additionalLayers ...Layer) Layers { - layersToSend := deepcopy.Copy(conn.outgoing).(Layers) - if err := layersToSend[udpLayerIndex].(*UDP).merge(udp); err != nil { - conn.t.Fatalf("can't merge %+v into %+v: %s", udp, layersToSend[udpLayerIndex], err) - } - layersToSend = append(layersToSend, additionalLayers...) - return layersToSend +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + return (*Connection)(conn).CreateFrame(layer, additionalLayers...) } -// SendFrame sends a frame with reasonable defaults. +// SendFrame sends a frame on the wire and updates the state of all layers. func (conn *UDPIPv4) SendFrame(frame Layers) { - outBytes, err := frame.toBytes() - if err != nil { - conn.t.Fatalf("can't build outgoing UDP packet: %s", err) - } - conn.injector.Send(outBytes) + (*Connection)(conn).SendFrame(frame) } -// Send a packet with reasonable defaults and override some fields by udp. -func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) { - conn.SendFrame(conn.CreateFrame(udp, additionalLayers...)) -} - -// Recv gets a packet from the sniffer within the timeout provided. If no packet -// arrives before the timeout, it returns nil. -func (conn *UDPIPv4) Recv(timeout time.Duration) *UDP { - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - break - } - b := conn.sniffer.Recv(timeout) - if b == nil { - break - } - layers, err := ParseEther(b) - if err != nil { - conn.t.Logf("can't parse frame: %s", err) - continue // Ignore packets that can't be parsed. - } - if !conn.incoming.match(layers) { - continue // Ignore packets that don't match the expected incoming. - } - return (layers[udpLayerIndex]).(*UDP) - } - return nil +// Close frees associated resources held by the UDPIPv4 connection. +func (conn *UDPIPv4) Close() { + (*Connection)(conn).Close() } -// Expect a packet that matches the provided udp within the timeout specified. -// If it doesn't arrive in time, the test fails. -func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) *UDP { - deadline := time.Now().Add(timeout) - for { - timeout = time.Until(deadline) - if timeout <= 0 { - return nil - } - gotUDP := conn.Recv(timeout) - if gotUDP == nil { - return nil - } - if udp.match(gotUDP) { - return gotUDP - } - } +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain() { + conn.sniffer.Drain() } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 4d6625941..5ce324f0d 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -15,6 +15,7 @@ package testbench import ( + "encoding/hex" "fmt" "reflect" "strings" @@ -64,6 +65,9 @@ type Layer interface { // setPrev sets the pointer to the Layer encapsulating this one. setPrev(Layer) + + // merge overrides the values in the interface with the provided values. 
+ merge(Layer) error } // LayerBase is the common elements of all layers. @@ -91,6 +95,9 @@ func (lb *LayerBase) setPrev(l Layer) { // equalLayer compares that two Layer structs match while ignoring field in // which either input has a nil and also ignoring the LayerBase of the inputs. func equalLayer(x, y Layer) bool { + if x == nil || y == nil { + return true + } // opt ignores comparison pairs where either of the inputs is a nil. opt := cmp.FilterValues(func(x, y interface{}) bool { for _, l := range []interface{}{x, y} { @@ -104,6 +111,15 @@ func equalLayer(x, y Layer) bool { return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) } +// mergeLayer merges other in layer. Any non-nil value in other overrides the +// corresponding value in layer. If other is nil, no action is performed. +func mergeLayer(layer, other Layer) error { + if other == nil { + return nil + } + return mergo.Merge(layer, other, mergo.WithOverride) +} + func stringLayer(l Layer) string { v := reflect.ValueOf(l).Elem() t := v.Type() @@ -118,7 +134,12 @@ func stringLayer(l Layer) string { if v.IsNil() { continue } - ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + v = reflect.Indirect(v) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) + } else { + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } } return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) } @@ -153,7 +174,7 @@ func (l *Ether) toBytes() ([]byte, error) { fields.Type = header.IPv4ProtocolNumber default: // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", n) + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) } } h.Encode(fields) @@ -172,27 +193,46 @@ func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocol return &v } -// ParseEther parses the bytes assuming that they start with an ethernet header +// layerParser parses the input bytes and returns a Layer along with the next +// layerParser to run. If there is no more parsing to do, the returned +// layerParser is nil. +type layerParser func([]byte) (Layer, layerParser) + +// parse parses bytes starting with the first layerParser and using successive +// layerParsers until all the bytes are parsed. +func parse(parser layerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + +// parseEther parses the bytes assuming that they start with an ethernet header // and continues parsing further encapsulations. -func ParseEther(b []byte) (Layers, error) { +func parseEther(b []byte) (Layer, layerParser) { h := header.Ethernet(b) ether := Ether{ SrcAddr: LinkAddress(h.SourceAddress()), DstAddr: LinkAddress(h.DestinationAddress()), Type: NetworkProtocolNumber(h.Type()), } - layers := Layers{ðer} + var nextParser layerParser switch h.Type() { case header.IPv4ProtocolNumber: - moreLayers, err := ParseIPv4(b[ether.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseIPv4 default: - // TODO(b/150301488): Support more protocols, like IPv6. - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %#v", b) + // Assume that the rest is a payload. 
+ nextParser = parsePayload } + return ðer, nextParser } func (l *Ether) match(other Layer) bool { @@ -203,6 +243,12 @@ func (l *Ether) length() int { return header.EthernetMinimumSize } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Ether) merge(other Layer) error { + return mergeLayer(l, other) +} + // IPv4 can construct and match an IPv4 encapsulation. type IPv4 struct { LayerBase @@ -274,7 +320,7 @@ func (l *IPv4) toBytes() ([]byte, error) { fields.Protocol = uint8(header.UDPProtocolNumber) default: // TODO(b/150301488): Support more protocols as needed. - return nil, fmt.Errorf("can't deduce the ip header's next protocol: %#v", n) + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) } } if l.SrcAddr != nil { @@ -311,9 +357,9 @@ func Address(v tcpip.Address) *tcpip.Address { return &v } -// ParseIPv4 parses the bytes assuming that they start with an ipv4 header and +// parseIPv4 parses the bytes assuming that they start with an ipv4 header and // continues parsing further encapsulations. -func ParseIPv4(b []byte) (Layers, error) { +func parseIPv4(b []byte) (Layer, layerParser) { h := header.IPv4(b) tos, _ := h.TOS() ipv4 := IPv4{ @@ -329,22 +375,17 @@ func ParseIPv4(b []byte) (Layers, error) { SrcAddr: Address(h.SourceAddress()), DstAddr: Address(h.DestinationAddress()), } - layers := Layers{&ipv4} + var nextParser layerParser switch h.TransportProtocol() { case header.TCPProtocolNumber: - moreLayers, err := ParseTCP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseTCP case header.UDPProtocolNumber: - moreLayers, err := ParseUDP(b[ipv4.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + nextParser = parseUDP + default: + // Assume that the rest is a payload. + nextParser = parsePayload } - return nil, fmt.Errorf("can't deduce the ethernet header's next protocol: %d", h.Protocol()) + return &ipv4, nextParser } func (l *IPv4) match(other Layer) bool { @@ -358,6 +399,12 @@ func (l *IPv4) length() int { return int(*l.IHL) } +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + // TCP can construct and match a TCP encapsulation. type TCP struct { LayerBase @@ -468,9 +515,9 @@ func Uint32(v uint32) *uint32 { return &v } -// ParseTCP parses the bytes assuming that they start with a tcp header and +// parseTCP parses the bytes assuming that they start with a tcp header and // continues parsing further encapsulations. -func ParseTCP(b []byte) (Layers, error) { +func parseTCP(b []byte) (Layer, layerParser) { h := header.TCP(b) tcp := TCP{ SrcPort: Uint16(h.SourcePort()), @@ -483,12 +530,7 @@ func ParseTCP(b []byte) (Layers, error) { Checksum: Uint16(h.Checksum()), UrgentPointer: Uint16(h.UrgentPointer()), } - layers := Layers{&tcp} - moreLayers, err := ParsePayload(b[tcp.length():]) - if err != nil { - return nil, err - } - return append(layers, moreLayers...), nil + return &tcp, parsePayload } func (l *TCP) match(other Layer) bool { @@ -504,8 +546,8 @@ func (l *TCP) length() int { // merge overrides the values in l with the values from other but only in fields // where the value is not nil. 
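
These merge methods all lean on mergo's override semantics. The standalone toy below (not testbench code; the github.com/imdario/mergo import path is assumed) shows the behavior mergeLayer depends on: non-nil fields of the source replace the destination's, nil fields leave it alone.

    package main

    import (
        "fmt"

        "github.com/imdario/mergo" // import path assumed
    )

    // pair stands in for a Layer: pointer fields so that nil means "unset".
    type pair struct {
        A *int
        B *int
    }

    func intPtr(v int) *int { return &v }

    func main() {
        dst := pair{A: intPtr(1), B: intPtr(2)}
        src := pair{B: intPtr(9)} // A is nil, so dst.A is kept

        // With mergo.WithOverride, every non-nil field of src replaces the
        // corresponding field of dst, which is the behavior mergeLayer builds on.
        if err := mergo.Merge(&dst, src, mergo.WithOverride); err != nil {
            panic(err)
        }
        fmt.Println(*dst.A, *dst.B) // prints: 1 9
    }
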
-func (l *TCP) merge(other TCP) error {
-	return mergo.Merge(l, other, mergo.WithOverride)
+func (l *TCP) merge(other Layer) error {
+	return mergeLayer(l, other)
 }
 
 // UDP can construct and match a UDP encapsulation.
@@ -556,9 +598,9 @@ func setUDPChecksum(h *header.UDP, udp *UDP) error {
 	return nil
 }
 
-// ParseUDP parses the bytes assuming that they start with a udp header and
-// continues parsing further encapsulations.
-func ParseUDP(b []byte) (Layers, error) {
+// parseUDP parses the bytes assuming that they start with a udp header and
+// returns the parsed layer and the next parser to use.
+func parseUDP(b []byte) (Layer, layerParser) {
 	h := header.UDP(b)
 	udp := UDP{
 		SrcPort:  Uint16(h.SourcePort()),
@@ -566,12 +608,7 @@ func ParseUDP(b []byte) (Layers, error) {
 		Length:   Uint16(h.Length()),
 		Checksum: Uint16(h.Checksum()),
 	}
-	layers := Layers{&udp}
-	moreLayers, err := ParsePayload(b[udp.length():])
-	if err != nil {
-		return nil, err
-	}
-	return append(layers, moreLayers...), nil
+	return &udp, parsePayload
 }
 
 func (l *UDP) match(other Layer) bool {
@@ -587,8 +624,8 @@ func (l *UDP) length() int {
 
 // merge overrides the values in l with the values from other but only in fields
 // where the value is not nil.
-func (l *UDP) merge(other UDP) error {
-	return mergo.Merge(l, other, mergo.WithOverride)
+func (l *UDP) merge(other Layer) error {
+	return mergeLayer(l, other)
 }
 
 // Payload has bytes beyond OSI layer 4.
@@ -601,13 +638,13 @@ func (l *Payload) String() string {
 	return stringLayer(l)
 }
 
-// ParsePayload parses the bytes assuming that they start with a payload and
+// parsePayload parses the bytes assuming that they start with a payload and
 // continue to the end. There can be no further encapsulations.
-func ParsePayload(b []byte) (Layers, error) {
+func parsePayload(b []byte) (Layer, layerParser) {
 	payload := Payload{
 		Bytes: b,
 	}
-	return Layers{&payload}, nil
+	return &payload, nil
 }
 
 func (l *Payload) toBytes() ([]byte, error) {
@@ -622,18 +659,33 @@ func (l *Payload) length() int {
 	return len(l.Bytes)
 }
 
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *Payload) merge(other Layer) error {
+	return mergeLayer(l, other)
+}
+
 // Layers is an array of Layer and supports similar functions to Layer.
 type Layers []Layer
 
-func (ls *Layers) toBytes() ([]byte, error) {
+// linkLayers sets the linked-list pointers in ls.
+func (ls *Layers) linkLayers() { for i, l := range *ls { if i > 0 { l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) } if i+1 < len(*ls) { l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) } } +} + +func (ls *Layers) toBytes() ([]byte, error) { + ls.linkLayers() outBytes := []byte{} for _, l := range *ls { layerBytes, err := l.toBytes() @@ -649,8 +701,8 @@ func (ls *Layers) match(other Layers) bool { if len(*ls) > len(other) { return false } - for i := 0; i < len(*ls); i++ { - if !equalLayer((*ls)[i], other[i]) { + for i, l := range *ls { + if !equalLayer(l, other[i]) { return false } } diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index b39839625..b32efda93 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -14,7 +14,11 @@ package testbench -import "testing" +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) func TestLayerMatch(t *testing.T) { var nilPayload *Payload @@ -47,3 +51,106 @@ func TestLayerMatch(t *testing.T) { } } } + +func TestLayerStringFormat(t *testing.T) { + for _, tt := range []struct { + name string + l Layer + want string + }{ + { + name: "TCP", + l: &TCP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + SeqNum: Uint32(3452155723), + AckNum: Uint32(2596996163), + DataOffset: Uint8(5), + Flags: Uint8(20), + WindowSize: Uint16(64240), + Checksum: Uint16(0x2e2b), + }, + want: "&testbench.TCP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "SeqNum:3452155723 " + + "AckNum:2596996163 " + + "DataOffset:5 " + + "Flags:20 " + + "WindowSize:64240 " + + "Checksum:11819" + + "}", + }, + { + name: "UDP", + l: &UDP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + Length: Uint16(12), + }, + want: "&testbench.UDP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "Length:12" + + "}", + }, + { + name: "IPv4", + l: &IPv4{ + IHL: Uint8(5), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(0), + Flags: Uint8(2), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(6), + Checksum: Uint16(0x2e2b), + SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), + DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), + }, + want: "&testbench.IPv4{" + + "IHL:5 " + + "TOS:0 " + + "TotalLength:44 " + + "ID:0 " + + "Flags:2 " + + "FragmentOffset:0 " + + "TTL:64 " + + "Protocol:6 " + + "Checksum:11819 " + + "SrcAddr:197.34.63.10 " + + "DstAddr:197.34.63.20" + + "}", + }, + { + name: "Ether", + l: &Ether{ + SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), + DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), + Type: NetworkProtocolNumber(4), + }, + want: "&testbench.Ether{" + + "SrcAddr:02:42:c5:22:3f:0a " + + "DstAddr:02:42:c5:22:3f:14 " + + "Type:4" + + "}", + }, + { + name: "Payload", + l: &Payload{ + Bytes: []byte("Hooray for packetimpact."), + }, + want: "&testbench.Payload{Bytes:\n" + + "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" + + "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" + + "}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.String(); got != tt.want { + t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 0074484f7..ff722d4a6 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -17,6 +17,7 @@ package testbench import ( 
"encoding/binary" "flag" + "fmt" "math" "net" "testing" @@ -97,12 +98,36 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { } } -// Close the socket that Sniffer is using. -func (s *Sniffer) Close() { +// Drain drains the Sniffer's socket receive buffer by receiving until there's +// nothing else to receive. +func (s *Sniffer) Drain() { + s.t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) + if err != nil { + s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { + s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + } + for { + buf := make([]byte, maxReadSize) + _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { + s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + } +} + +// close the socket that Sniffer is using. +func (s *Sniffer) close() error { if err := unix.Close(s.fd); err != nil { - s.t.Fatalf("can't close sniffer socket: %s", err) + return fmt.Errorf("can't close sniffer socket: %w", err) } s.fd = -1 + return nil } // Injector can inject raw frames. @@ -148,10 +173,11 @@ func (i *Injector) Send(b []byte) { } } -// Close the underlying socket. -func (i *Injector) Close() { +// close the underlying socket. +func (i *Injector) close() error { if err := unix.Close(i.fd); err != nil { - i.t.Fatalf("can't close sniffer socket: %s", err) + return fmt.Errorf("can't close sniffer socket: %w", err) } i.fd = -1 + return nil } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index a9b2de9b9..690cee140 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -40,6 +40,54 @@ packetimpact_go_test( ], ) +packetimpact_go_test( + name = "tcp_outside_the_window", + srcs = ["tcp_outside_the_window_test.go"], + # TODO(eyalsoha): Fix #1607 then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_noaccept_close_rst", + srcs = ["tcp_noaccept_close_rst_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_should_piggyback", + srcs = ["tcp_should_piggyback_test.go"], + # TODO(b/153680566): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_close_wait_ack", + srcs = ["tcp_close_wait_ack_test.go"], + # TODO(b/153574037): Fix netstack then remove the line below. 
+ netstack = False, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + sh_binary( name = "test_runner", srcs = ["test_runner.sh"], diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 2b3f39045..b98594f94 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -47,20 +47,22 @@ func TestFinWait2Timeout(t *testing.T) { } dut.Close(acceptFd) - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); gotOne == nil { - t.Fatal("expected a FIN-ACK within 1 second but got none") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) } conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) time.Sleep(5 * time.Second) + conn.Drain() + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) if tt.linger2 { - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); gotOne == nil { - t.Fatal("expected a RST packet within a second but got none") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST packet within a second but got none: %s", err) } } else { - if gotOne := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); gotOne != nil { - t.Fatal("expected no RST packets within ten seconds but got one") + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); err == nil { + t.Fatalf("expected no RST packets within ten seconds but got one: %s", err) } } }) diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go new file mode 100644 index 000000000..eb4cc7a65 --- /dev/null +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -0,0 +1,102 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package tcp_close_wait_ack_test
+
+import (
+	"fmt"
+	"testing"
+	"time"
+
+	"golang.org/x/sys/unix"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+	tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func TestCloseWaitAck(t *testing.T) {
+	for _, tt := range []struct {
+		description    string
+		makeTestingTCP func(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP
+		seqNumOffset   seqnum.Size
+		expectAck      bool
+	}{
+		{"OTW", GenerateOTWSeqSegment, 0, false},
+		{"OTW", GenerateOTWSeqSegment, 1, true},
+		{"OTW", GenerateOTWSeqSegment, 2, true},
+		{"ACK", GenerateUnaccACKSegment, 0, false},
+		{"ACK", GenerateUnaccACKSegment, 1, true},
+		{"ACK", GenerateUnaccACKSegment, 2, true},
+	} {
+		t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+			dut := tb.NewDUT(t)
+			defer dut.TearDown()
+			listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+			defer dut.Close(listenFd)
+			conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+			defer conn.Close()
+
+			conn.Handshake()
+			acceptFd, _ := dut.Accept(listenFd)
+
+			// Send a FIN to the DUT to initiate the active close.
+			conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+			if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+				t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
+			}
+
+			// Send a segment with an OTW Seq / unacceptable ACK and expect an ACK back.
+			conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset), &tb.Payload{Bytes: []byte("Sample Data")})
+			gotAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second)
+			if tt.expectAck && err != nil {
+				t.Fatalf("expected an ack but got none: %s", err)
+			}
+			if !tt.expectAck && gotAck != nil {
+				t.Fatalf("expected no ack but got one: %s", gotAck)
+			}
+
+			// Now verify that the DUT is indeed in CLOSE_WAIT.
+			dut.Close(acceptFd)
+			if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+				t.Fatalf("expected DUT to send a FIN: %s", err)
+			}
+			// ACK the FIN from the DUT.
+			conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+			// Send some extra data to the DUT.
+			conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: []byte("Sample Data")})
+			if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+				t.Fatalf("expected DUT to send an RST: %s", err)
+			}
+		})
+	}
+}
+
+// GenerateOTWSeqSegment generates a segment with seqnum = RCV.NXT + RCV.WND - 1
+// + seqNumOffset, i.e. seqNumOffset past the last acceptable sequence number.
+// The generated segment is only acceptable when seqNumOffset is 0; otherwise an
+// ACK is expected from the receiver.
+func GenerateOTWSeqSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP {
+	windowSize := seqnum.Size(*conn.SynAck().WindowSize)
+	lastAcceptable := conn.LocalSeqNum().Add(windowSize - 1)
+	otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
+	return tb.TCP{SeqNum: tb.Uint32(otwSeq), Flags: tb.Uint8(header.TCPFlagAck)}
+}
+
+// GenerateUnaccACKSegment generates a segment with acknum = SND.NXT +
+// seqNumOffset. The generated segment is only acceptable when seqNumOffset is
+// 0; otherwise an ACK is expected from the receiver.
+func GenerateUnaccACKSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP { + lastAcceptable := conn.RemoteSeqNum() + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return tb.TCP{AckNum: tb.Uint32(unaccAck), Flags: tb.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go new file mode 100644 index 000000000..7ebdd1950 --- /dev/null +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_noaccept_close_rst_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestTcpNoAcceptCloseReset(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + conn.Handshake() + defer conn.Close() + dut.Close(listenFd) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + t.Fatalf("expected a RST-ACK packet but got none: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go new file mode 100644 index 000000000..db3d3273b --- /dev/null +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_outside_the_window_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive +// that are inside or outside the TCP window. 
Packets that are outside the +// window should force an extra ACK, as described in RFC793 page 69: +// https://tools.ietf.org/html/rfc793#page-69 +func TestTCPOutsideTheWindow(t *testing.T) { + for _, tt := range []struct { + description string + tcpFlags uint8 + payload []tb.Layer + seqNumOffset seqnum.Size + expectACK bool + }{ + {"SYN", header.TCPFlagSyn, nil, 0, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true}, + {"ACK", header.TCPFlagAck, nil, 0, false}, + {"FIN", header.TCPFlagFin, nil, 0, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 0, true}, + + {"SYN", header.TCPFlagSyn, nil, 1, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true}, + {"ACK", header.TCPFlagAck, nil, 1, true}, + {"FIN", header.TCPFlagFin, nil, 1, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 1, true}, + + {"SYN", header.TCPFlagSyn, nil, 2, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true}, + {"ACK", header.TCPFlagAck, nil, 2, true}, + {"FIN", header.TCPFlagFin, nil, 2, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Handshake() + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset + conn.Drain() + // Ignore whatever incrementing that this out-of-order packet might cause + // to the AckNum. + localSeqNum := tb.Uint32(uint32(*conn.LocalSeqNum())) + conn.Send(tb.TCP{ + Flags: tb.Uint8(tt.tcpFlags), + SeqNum: tb.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + }, tt.payload...) + timeout := 3 * time.Second + gotACK, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + if tt.expectACK && err != nil { + t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) + } + if !tt.expectACK && gotACK != nil { + t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go new file mode 100644 index 000000000..b0be6ba23 --- /dev/null +++ b/test/packetimpact/tests/tcp_should_piggyback_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tcp_should_piggyback_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestPiggyback(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort, WindowSize: tb.Uint16(12)}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + + dut.Send(acceptFd, sampleData, 0) + expectedTCP := tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} + expectedPayload := tb.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } + + // Cause DUT to send us more data as soon as we ACK their first data segment because we have + // a small window. + dut.Send(acceptFd, sampleData, 0) + + // DUT should ACK our segment by piggybacking ACK to their outstanding data segment instead of + // sending a separate ACK packet. + conn.Send(expectedTCP, &expectedPayload) + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } +} diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go index b48cc6491..c9354074e 100644 --- a/test/packetimpact/tests/tcp_window_shrink_test.go +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -38,15 +38,22 @@ func TestWindowShrink(t *testing.T) { dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") + samplePayload := &tb.Payload{Bytes: sampleData} dut.Send(acceptFd, sampleData, 0) - conn.ExpectData(tb.TCP{}, sampleData, time.Second) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) dut.Send(acceptFd, sampleData, 0) dut.Send(acceptFd, sampleData, 0) - conn.ExpectData(tb.TCP{}, sampleData, time.Second) - conn.ExpectData(tb.TCP{}, sampleData, time.Second) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } // We close our receiving window here conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)}) @@ -54,5 +61,8 @@ func TestWindowShrink(t *testing.T) { // Note: There is another kind of zero-window probing which Windows uses (by sending one // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change // the following lines. 
- conn.ExpectData(tb.TCP{SeqNum: tb.Uint32(uint32(conn.RemoteSeqNum - 1))}, nil, time.Second) + expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 + if _, err := conn.ExpectData(&tb.TCP{SeqNum: tb.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %v: %s", expectedRemoteSeqNum, err) + } } diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go index bc1b0be49..61fd17050 100644 --- a/test/packetimpact/tests/udp_recv_multicast_test.go +++ b/test/packetimpact/tests/udp_recv_multicast_test.go @@ -30,7 +30,7 @@ func TestUDPRecvMulticast(t *testing.T) { defer dut.Close(boundFD) conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort}) defer conn.Close() - frame := conn.CreateFrame(tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")}) + frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")}) frame[1].(*tb.IPv4).DstAddr = tb.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4())) conn.SendFrame(frame) dut.Recv(boundFD, 100, 0) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index 4038661cb..679342def 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -53,7 +53,7 @@ func verifyPid(pid int, path string) error { if scanner.Err() != nil { return scanner.Err() } - return fmt.Errorf("got: %s, want: %d", gots, pid) + return fmt.Errorf("got: %v, want: %d", gots, pid) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. @@ -106,7 +106,7 @@ func TestMemCGroup(t *testing.T) { time.Sleep(100 * time.Millisecond) } - t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20) + t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 126f0975a..22488b05d 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -46,7 +46,7 @@ func TestOOMScoreAdjSingle(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir ppid, err := specutils.GetParentPid(os.Getpid()) @@ -137,7 +137,7 @@ func TestOOMScoreAdjMulti(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfig() + conf := testutil.TestConfig(t) conf.RootDir = rootDir ppid, err := specutils.GetParentPid(os.Getpid()) diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go index 52f49b984..0ff69ab18 100644 --- a/test/runtimes/blacklist_test.go +++ b/test/runtimes/blacklist_test.go @@ -32,6 +32,6 @@ func TestBlacklists(t *testing.T) { t.Fatalf("error parsing blacklist: %v", err) } if *blacklistFile != "" && len(bl) == 0 { - t.Errorf("got empty blacklist for file %q", blacklistFile) + t.Errorf("got empty blacklist for file %q", *blacklistFile) } } diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go index ddb890dbc..3c98f4570 100644 --- a/test/runtimes/runner.go +++ b/test/runtimes/runner.go @@ -114,7 +114,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int F: func(t *testing.T) { // Is the test blacklisted? 
if _, ok := blacklist[tc]; ok { - t.Skip("SKIP: blacklisted test %q", tc) + t.Skipf("SKIP: blacklisted test %q", tc) } var ( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index d0c431234..d9095c95f 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -138,7 +138,6 @@ cc_library( hdrs = ["socket_netlink_route_util.h"], deps = [ ":socket_netlink_util", - "@com_google_absl//absl/types:optional", ], ) @@ -663,10 +662,7 @@ cc_binary( cc_binary( name = "exec_binary_test", testonly = 1, - srcs = select_arch( - amd64 = ["exec_binary.cc"], - arm64 = [], - ), + srcs = ["exec_binary.cc"], linkstatic = 1, deps = [ "//test/util:cleanup", @@ -2026,6 +2022,8 @@ cc_binary( "//test/util:file_descriptor", "@com_google_absl//absl/strings", gtest, + ":ip_socket_test_util", + ":unix_domain_socket_test_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -2802,13 +2800,13 @@ cc_binary( srcs = ["socket_netlink_route.cc"], linkstatic = 1, deps = [ + ":socket_netlink_route_util", ":socket_netlink_util", ":socket_test_util", "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", gtest, "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc index a33daff17..806d5729e 100644 --- a/test/syscalls/linux/aio.cc +++ b/test/syscalls/linux/aio.cc @@ -89,6 +89,7 @@ class AIOTest : public FileTest { FileTest::TearDown(); if (ctx_ != 0) { ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } } @@ -188,14 +189,19 @@ TEST_F(AIOTest, BadWrite) { } TEST_F(AIOTest, ExitWithPendingIo) { - // Setup a context that is 5 entries deep. - ASSERT_THAT(SetupContext(5), SyscallSucceeds()); + // Setup a context that is 100 entries deep. + ASSERT_THAT(SetupContext(100), SyscallSucceeds()); struct iocb cb = CreateCallback(); struct iocb* cbs[] = {&cb}; // Submit a request but don't complete it to make it pending. - EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + for (int i = 0; i < 100; ++i) { + EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + } + + ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } int Submitter(void* arg) { diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index a4f8f3cec..f57d38dc7 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -56,10 +56,6 @@ TEST(EpollTest, AllWritable) { struct epoll_event result[kFDsPerEpoll]; ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), SyscallSucceedsWithValue(kFDsPerEpoll)); - // TODO(edahlgren): Why do some tests check epoll_event::data, and others - // don't? Does Linux actually guarantee that, in any of these test cases, - // epoll_wait will necessarily write out the epoll_events in the order that - // they were registered? for (int i = 0; i < kFDsPerEpoll; i++) { ASSERT_EQ(result[i].events, EPOLLOUT); } diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 736452b0c..1a9f203b9 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -48,10 +48,17 @@ namespace { using ::testing::AnyOf; using ::testing::Eq; -#ifndef __x86_64__ +#if !defined(__x86_64__) && !defined(__aarch64__) // The assembly stub and ELF internal details must be ported to other arches. 
-#error "Test only supported on x86-64" -#endif // __x86_64__ +#error "Test only supported on x86-64/arm64" +#endif // __x86_64__ || __aarch64__ + +#if defined(__x86_64__) +#define EM_TYPE EM_X86_64 +#define IP_REG(p) ((p).rip) +#define RAX_REG(p) ((p).rax) +#define RDI_REG(p) ((p).rdi) +#define RETURN_REG(p) ((p).rax) // amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP. const char kPtraceCode[] = { @@ -139,6 +146,76 @@ const char kPtraceCode[] = { // Size of a syscall instruction. constexpr int kSyscallSize = 2; +#elif defined(__aarch64__) +#define EM_TYPE EM_AARCH64 +#define IP_REG(p) ((p).pc) +#define RAX_REG(p) ((p).regs[8]) +#define RDI_REG(p) ((p).regs[0]) +#define RETURN_REG(p) ((p).regs[0]) + +const char kPtraceCode[] = { + // MOVD $117, R8 /* ptrace */ + '\xa8', + '\x0e', + '\x80', + '\xd2', + // MOVD $0, R0 /* PTRACE_TRACEME */ + '\x00', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R1 /* pid */ + '\x01', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R2 /* addr */ + '\x02', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R3 /* data */ + '\x03', + '\x00', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $172, R8 /* getpid */ + '\x88', + '\x15', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $129, R8 /* kill, R0=pid */ + '\x28', + '\x10', + '\x80', + '\xd2', + // MOVD $19, R1 /* SIGSTOP */ + '\x61', + '\x02', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', +}; +// Size of a syscall instruction. +constexpr int kSyscallSize = 4; +#else +#error "Unknown architecture" +#endif + // This test suite tests executable loading in the kernel (ELF and interpreter // scripts). @@ -281,7 +358,7 @@ ElfBinary<64> StandardElf() { elf.header.e_ident[EI_DATA] = ELFDATA2LSB; elf.header.e_ident[EI_VERSION] = EV_CURRENT; elf.header.e_type = ET_EXEC; - elf.header.e_machine = EM_X86_64; + elf.header.e_machine = EM_TYPE; elf.header.e_version = EV_CURRENT; elf.header.e_phoff = sizeof(elf.header); elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr); @@ -327,9 +404,15 @@ TEST(ElfTest, Execute) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - // RIP is just beyond the final syscall instruction. - EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode)); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); + // RIP/PC is just beyond the final syscall instruction. + EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode)); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, @@ -718,9 +801,16 @@ TEST(ElfTest, PIE) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ // text page. 
@@ -787,9 +877,15 @@ TEST(ElfTest, PIENonZeroStart) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); // The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr. // @@ -910,9 +1006,15 @@ TEST(ElfTest, ELFInterpreter) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1084,9 +1186,15 @@ TEST(ElfTest, ELFInterpreterRelative) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1480,14 +1588,21 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // RIP is just beyond the final syscall instruction. Rewind to execute a brk // syscall. - regs.rip -= kSyscallSize; - regs.rax = __NR_brk; - regs.rdi = 0; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, ®s), SyscallSucceeds()); + IP_REG(regs) -= kSyscallSize; + RAX_REG(regs) = __NR_brk; + RDI_REG(regs) = 0; + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); // Resume the child, waiting for syscall entry. ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds()); @@ -1504,7 +1619,12 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) << "status = " << status; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // brk is after the text page. 
// @@ -1512,7 +1632,7 @@ TEST(ExecveTest, BrkAfterBinary) { // address will be, but it is always beyond the final page in the binary. // i.e., it does not start immediately after memsz in the middle of a page. // Userspace may expect to use that space. - EXPECT_GE(regs.rax, 0x41000); + EXPECT_GE(RETURN_REG(regs), 0x41000); } } // namespace diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 6f80bc97c..fb418e052 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -52,17 +52,6 @@ class FileTest : public ::testing::Test { test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE( Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR)); - // FIXME(edahlgren): enable when mknod syscall is supported. - // test_fifo_name_ = NewTempAbsPath(); - // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0, - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(), - // O_WRONLY), - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(), - // O_RDONLY), - // SyscallSucceeds()); - ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds()); ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds()); } @@ -96,18 +85,12 @@ class FileTest : public ::testing::Test { CloseFile(); UnlinkFile(); ClosePipes(); - - // FIXME(edahlgren): enable when mknod syscall is supported. - // close(test_fifo_[0]); - // close(test_fifo_[1]); - // unlink(test_fifo_name_.c_str()); } + protected: std::string test_file_name_; - std::string test_fifo_name_; FileDescriptor test_file_fd_; - int test_fifo_[2]; int test_pipe_[2]; }; diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc index ff8bdfeb0..853f6231a 100644 --- a/test/syscalls/linux/fork.cc +++ b/test/syscalls/linux/fork.cc @@ -431,7 +431,6 @@ TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) { << "status = " << status; } -#ifdef __x86_64__ // Clone with CLONE_SETTLS and a non-canonical TLS address is rejected. TEST(CloneTest, NonCanonicalTLS) { constexpr uintptr_t kNonCanonical = 1ull << 48; @@ -440,11 +439,25 @@ TEST(CloneTest, NonCanonicalTLS) { // on this. 
char stack; + // The raw system call interface on x86-64 is: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, int *child_tid, + // unsigned long tls); + // + // While on arm64, the order of the last two arguments is reversed: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, unsigned long tls, + // int *child_tid); +#if defined(__x86_64__) EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, nullptr, kNonCanonical), SyscallFailsWithErrno(EPERM)); -} +#elif defined(__aarch64__) + EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, + kNonCanonical, nullptr), + SyscallFailsWithErrno(EPERM)); #endif +} } // namespace } // namespace testing diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc index f97f60029..f87cdd7a1 100644 --- a/test/syscalls/linux/getrandom.cc +++ b/test/syscalls/linux/getrandom.cc @@ -29,6 +29,8 @@ namespace { #define SYS_getrandom 318 #elif defined(__i386__) #define SYS_getrandom 355 +#elif defined(__aarch64__) +#define SYS_getrandom 278 #else #error "Unknown architecture" #endif diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index bba022a41..98d07ae85 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -16,7 +16,6 @@ #include <net/if.h> #include <netinet/in.h> -#include <sys/ioctl.h> #include <sys/socket.h> #include <cstring> @@ -35,12 +34,11 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr) { } PosixErrorOr<int> InterfaceIndex(std::string name) { - // TODO(igudger): Consider using netlink. - ifreq req = {}; - memcpy(req.ifr_name, name.c_str(), name.size()); - ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0)); - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req)); - return req.ifr_ifindex; + int index = if_nametoindex(name.c_str()); + if (index) { + return index; + } + return PosixError(errno); } namespace { @@ -177,17 +175,17 @@ SocketKind IPv6TCPUnboundSocket(int type) { PosixError IfAddrHelper::Load() { Release(); RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); - return PosixError(0); + return NoError(); } void IfAddrHelper::Release() { if (ifaddr_) { freeifaddrs(ifaddr_); + ifaddr_ = nullptr; } - ifaddr_ = nullptr; } -std::vector<std::string> IfAddrHelper::InterfaceList(int family) { +std::vector<std::string> IfAddrHelper::InterfaceList(int family) const { std::vector<std::string> names; for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { @@ -198,7 +196,7 @@ std::vector<std::string> IfAddrHelper::InterfaceList(int family) { return names; } -sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { +const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const { for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { continue; @@ -210,7 +208,7 @@ sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { return nullptr; } -PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) { +PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const { return InterfaceIndex(name); } diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 39fd6709d..9c3859fcd 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -110,10 +110,10 @@ class 
IfAddrHelper { PosixError Load(); void Release(); - std::vector<std::string> InterfaceList(int family); + std::vector<std::string> InterfaceList(int family) const; - struct sockaddr* GetAddr(int family, std::string name); - PosixErrorOr<int> GetIndex(std::string name); + const sockaddr* GetAddr(int family, std::string name) const; + PosixErrorOr<int> GetIndex(std::string name) const; private: struct ifaddrs* ifaddr_; diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc index a8af8e545..6ce1e6cc3 100644 --- a/test/syscalls/linux/lseek.cc +++ b/test/syscalls/linux/lseek.cc @@ -53,7 +53,7 @@ TEST(LseekTest, NegativeOffset) { // A 32-bit off_t is not large enough to represent an offset larger than // maximum file size on standard file systems, so it isn't possible to cause // overflow. -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST(LseekTest, Overflow) { // HA! Classic Linux. We really should have an EOVERFLOW // here, since we're seeking to something that cannot be diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index e57b49a4a..f8b7f7938 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -16,6 +16,7 @@ #include <fcntl.h> #include <linux/magic.h> #include <linux/memfd.h> +#include <linux/unistd.h> #include <string.h> #include <sys/mman.h> #include <sys/statfs.h> diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc index 367a90fe1..78ac96bed 100644 --- a/test/syscalls/linux/mlock.cc +++ b/test/syscalls/linux/mlock.cc @@ -199,8 +199,10 @@ TEST(MunlockallTest, Basic) { } #ifndef SYS_mlock2 -#ifdef __x86_64__ +#if defined(__x86_64__) #define SYS_mlock2 325 +#elif defined(__aarch64__) +#define SYS_mlock2 284 #endif #endif diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 11fb1b457..6d3227ab6 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -361,7 +361,7 @@ TEST_F(MMapTest, MapFixed) { } // 64-bit addresses work too -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST_F(MMapTest, MapFixed64) { EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0), @@ -571,6 +571,12 @@ const uint8_t machine_code[] = { 0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax 0xc3, // retq }; +#elif defined(__aarch64__) +const uint8_t machine_code[] = { + 0x40, 0x05, 0x80, 0x52, // mov w0, #42 + 0xc0, 0x03, 0x5f, 0xd6, // ret +}; +#endif // PROT_EXEC allows code execution TEST_F(MMapTest, ProtExec) { @@ -605,7 +611,6 @@ TEST_F(MMapTest, NoProtExecDeath) { EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), ""); } -#endif TEST_F(MMapTest, NoExceedLimitData) { void* prevbrk; @@ -1644,6 +1649,7 @@ TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) { } // Conditional on MAP_32BIT. +// This flag is supported only on x86-64, for 64-bit programs. 
#ifdef __x86_64__ TEST(MMapNoFixtureTest, Map32Bit) { diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index d8e19e910..67228b66b 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -265,6 +265,8 @@ TEST_P(PipeTest, OffsetCalls) { SyscallFailsWithErrno(ESPIPE)); struct iovec iov; + iov.iov_base = &buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); } diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index 2cecf2e5f..bcdbbb044 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/mman.h> #include <sys/socket.h> #include <sys/types.h> @@ -118,6 +119,21 @@ TEST_F(Pread64Test, EndOfFile) { EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST_F(Pread64Test, Overflow) { + int f = memfd_create("negative", 0); + const FileDescriptor fd(f); + + EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds()); + + char buf[10]; + EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); +} + TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) { int sock_fds[2]; EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds()); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 5a70f6c3b..79a625ebc 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -994,7 +994,7 @@ constexpr uint64_t kMappingSize = 100 << 20; // Tolerance on RSS comparisons to account for background thread mappings, // reclaimed pages, newly faulted pages, etc. -constexpr uint64_t kRSSTolerance = 5 << 20; +constexpr uint64_t kRSSTolerance = 10 << 20; // Capture RSS before and after an anonymous mapping with passed prot. void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) { @@ -1326,8 +1326,6 @@ TEST(ProcPidSymlink, SubprocessRunning) { SyscallSucceedsWithValue(sizeof(buf))); } -// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux -// on proc files. TEST(ProcPidSymlink, SubprocessZombied) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); @@ -1337,7 +1335,7 @@ TEST(ProcPidSymlink, SubprocessZombied) { int want = EACCES; if (!IsRunningOnGvisor()) { auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - if (version.major == 4 && version.minor > 3) { + if (version.major > 4 || (version.major == 4 && version.minor > 3)) { want = ENOENT; } } @@ -1350,30 +1348,25 @@ TEST(ProcPidSymlink, SubprocessZombied) { SyscallFailsWithErrno(want)); } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between linux on proc + // files. // // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. + // 4.17: Syscall succeeds and returns 1. // - // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + if (!IsRunningOnGvisor()) { + return; + } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. 
- // - // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. - // - // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); + + EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); } // Test whether /proc/PID/ symlinks can be read for an exited process. TEST(ProcPidSymlink, SubprocessExited) { - // FIXME(gvisor.dev/issue/164): These all succeed on gVisor. - SKIP_IF(IsRunningOnGvisor()); - char buf[1]; EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)), diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 4e23d1e78..cac394910 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -353,7 +353,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { EXPECT_EQ(oldNoPorts, newNoPorts - 1); } -TEST(ProcNetSnmp, UdpIn) { +TEST(ProcNetSnmp, UdpIn_NoRandomSave) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. const DisableSave ds; diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index cb828ff88..926690eb8 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -400,9 +400,11 @@ TEST(PtraceTest, GetRegSet) { // Read exactly the full register set. EXPECT_EQ(iov.iov_len, sizeof(regs)); -#ifdef __x86_64__ +#if defined(__x86_64__) // Child called kill(2), with SIGSTOP as arg 2. EXPECT_EQ(regs.rsi, SIGSTOP); +#elif defined(__aarch64__) + EXPECT_EQ(regs.regs[1], SIGSTOP); #endif // Suppress SIGSTOP and resume the child. @@ -752,15 +754,23 @@ TEST(PtraceTest, SyscallSucceeds()); EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80)) << "si_code = " << siginfo.si_code; -#ifdef __x86_64__ + { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = &regs; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone) << "orig_rax = " << regs.orig_rax; EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8]; + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // After this point, the child will be making wait4 syscalls that will be // interrupted by saving, so saving is not permitted. Note that this is @@ -805,14 +815,21 @@ TEST(PtraceTest, SyscallSucceedsWithValue(child_pid)); EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) << " status " << status; -#ifdef __x86_64__ { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = &regs; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_EQ(SYS_wait4, regs.orig_rax); EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_EQ(SYS_wait4, regs.regs[8]); + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // Detach from the child and wait for it to exit.
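The ptrace.cc hunks above spell the per-architecture register mapping out inline (rsi, orig_rax and rax on x86-64 versus regs[1], regs[8] and regs[0] on arm64), while the elf.cc hunks earlier route the same information through IP_REG, RAX_REG, RDI_REG and RETURN_REG helpers whose definitions are not shown in this diff. The sketch below is one plausible spelling of the mapping those helpers imply, written against glibc's user_regs_struct; it is an assumption, not a quote of the tree:

#include <sys/user.h>  // user_regs_struct

#if defined(__x86_64__)
#define IP_REG(regs)     ((regs).rip)      // instruction pointer
#define RAX_REG(regs)    ((regs).rax)      // syscall number at the syscall instruction
#define RDI_REG(regs)    ((regs).rdi)      // first syscall argument
#define RETURN_REG(regs) ((regs).rax)      // syscall return value
#elif defined(__aarch64__)
#define IP_REG(regs)     ((regs).pc)       // instruction pointer
#define RAX_REG(regs)    ((regs).regs[8])  // syscall number lives in x8
#define RDI_REG(regs)    ((regs).regs[0])  // first syscall argument in x0
#define RETURN_REG(regs) ((regs).regs[0])  // return value also comes back in x0
#endif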
ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index b48fe540d..e69794910 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> @@ -27,14 +28,7 @@ namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is not incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class Pwrite64 : public ::testing::Test { void SetUp() override { name_ = NewTempAbsPath(); @@ -72,6 +66,17 @@ TEST_F(Pwrite64, InvalidArgs) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(Pwrite64, Overflow) { + int fd; + ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds()); + constexpr int64_t kBufSize = 1024; + std::vector<char> buf(kBufSize); + std::fill(buf.begin(), buf.end(), 'a'); + EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index ebaafe47e..64123e904 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -70,6 +71,28 @@ TEST(SendFileTest, InvalidOffset) { SyscallFailsWithErrno(EINVAL)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST(SendFileTest, Overflow) { + // Create input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Open the output file. + int fd; + EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds()); + const FileDescriptor outf(fd); + + // out_offset + kSize overflows INT64_MAX. + loff_t out_offset = 0x7ffffffffffffffeull; + constexpr int kSize = 3; + EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize), + SyscallFailsWithErrno(EINVAL)); +} + TEST(SendFileTest, SendTrivially) { // Create temp files. constexpr char kData[] = "To be, or not to be, that is the question:"; diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc index e94672679..c101fe9d2 100644 --- a/test/syscalls/linux/sendfile_socket.cc +++ b/test/syscalls/linux/sendfile_socket.cc @@ -23,6 +23,7 @@ #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" @@ -35,61 +36,39 @@ namespace { class SendFileTest : public ::testing::TestWithParam<int> { protected: - PosixErrorOr<std::tuple<int, int>> Sockets() { + PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) { // Bind a server socket. 
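The Overflow tests added to pwrite64.cc and sendfile.cc above, like the pread64.cc one earlier in this diff, all probe the same rule: when the starting offset plus the transfer length would wrap past the largest positive 64-bit file offset, Linux rejects the call with EINVAL up front. The local memfd_create wrappers go through syscall(2) directly, presumably because a glibc wrapper cannot be assumed on every build host. A small standalone illustration of the offset check, using a hypothetical scratch path rather than the test fixtures:

#include <fcntl.h>
#include <unistd.h>

#include <cerrno>
#include <cstdio>

int main() {
  // Hypothetical scratch file; any writable regular file would do.
  int fd = open("/tmp/offset_overflow_demo", O_CREAT | O_RDWR, 0644);
  if (fd < 0) {
    std::perror("open");
    return 1;
  }
  const char buf[3] = {'a', 'b', 'c'};
  // offset + count would overflow the signed 64-bit file offset, so the call
  // fails with EINVAL (the errno the tests above assert) and writes nothing.
  ssize_t n = pwrite(fd, buf, sizeof(buf), 0x7fffffffffffffffLL);
  std::printf("pwrite returned %zd, errno = %d (EINVAL = %d)\n", n, errno, EINVAL);
  close(fd);
  return 0;
}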
int family = GetParam(); - struct sockaddr server_addr = {}; switch (family) { case AF_INET: { - struct sockaddr_in* server_addr_in = - reinterpret_cast<struct sockaddr_in*>(&server_addr); - server_addr_in->sin_family = family; - server_addr_in->sin_addr.s_addr = INADDR_ANY; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "TCP", AF_INET, type, 0, + TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } else { + return SocketPairKind{ + "UDP", AF_INET, type, 0, + UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } } case AF_UNIX: { - struct sockaddr_un* server_addr_un = - reinterpret_cast<struct sockaddr_un*>(&server_addr); - server_addr_un->sun_family = family; - server_addr_un->sun_path[0] = '\0'; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } else { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } } default: return PosixError(EINVAL); } - int server = socket(family, SOCK_STREAM, 0); - if (bind(server, &server_addr, sizeof(server_addr)) < 0) { - return PosixError(errno); - } - if (listen(server, 1) < 0) { - close(server); - return PosixError(errno); - } - - // Fetch the address; both are anonymous. - socklen_t length = sizeof(server_addr); - if (getsockname(server, &server_addr, &length) < 0) { - close(server); - return PosixError(errno); - } - - // Connect the client. - int client = socket(family, SOCK_STREAM, 0); - if (connect(client, &server_addr, length) < 0) { - close(server); - close(client); - return PosixError(errno); - } - - // Accept on the server. - int server_client = accept(server, nullptr, 0); - if (server_client < 0) { - close(server); - close(client); - return PosixError(errno); - } - close(server); - return std::make_tuple(client, server_client); } }; @@ -106,9 +85,7 @@ TEST_P(SendFileTest, SendMultiple) { const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); // Create sockets. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor server(std::get<0>(fds)); - FileDescriptor client(std::get<1>(fds)); // non-const, reset is used. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // Thread that reads data from socket and dumps to a file. ScopedThread th([&] { @@ -118,7 +95,7 @@ TEST_P(SendFileTest, SendMultiple) { // Read until socket is closed. char buf[10240]; for (int cnt = 0;; cnt++) { - int r = RetryEINTR(read)(server.get(), buf, sizeof(buf)); + int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf)); // We cannot afford to save on every read() call. if (cnt % 1000 == 0) { ASSERT_THAT(r, SyscallSucceeds()); @@ -152,7 +129,7 @@ TEST_P(SendFileTest, SendMultiple) { << ", remain=" << remain << std::endl; // Send data and verify that sendfile returns the correct value. - int res = sendfile(client.get(), inf.get(), nullptr, remain); + int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain); // We cannot afford to save on every sendfile() call. if (cnt % 120 == 0) { MaybeSave(); @@ -169,7 +146,7 @@ TEST_P(SendFileTest, SendMultiple) { } // Close socket to stop thread. - client.reset(); + close(socks->release_second_fd()); th.Join(); // Verify that the output file has the correct data. @@ -183,9 +160,7 @@ TEST_P(SendFileTest, SendMultiple) { TEST_P(SendFileTest, Shutdown) { // Create a socket. 
- std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, reset below. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // If this is a TCP socket, then turn off linger. if (GetParam() == AF_INET) { @@ -193,7 +168,7 @@ TEST_P(SendFileTest, Shutdown) { sl.l_onoff = 1; sl.l_linger = 0; ASSERT_THAT( - setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), SyscallSucceeds()); } @@ -212,12 +187,12 @@ TEST_P(SendFileTest, Shutdown) { ScopedThread t([&]() { size_t done = 0; while (done < data.size()) { - int n = RetryEINTR(read)(server.get(), data.data(), data.size()); + int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size()); ASSERT_THAT(n, SyscallSucceeds()); done += n; } // Close the server side socket. - server.reset(); + close(socks->release_first_fd()); }); // Continuously stream from the file to the socket. Note we do not assert @@ -225,7 +200,7 @@ TEST_P(SendFileTest, Shutdown) { // data is written. Eventually, we should get a connection reset error. while (1) { off_t offset = 0; // Always read from the start. - int n = sendfile(client.get(), inf.get(), &offset, data.size()); + int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size()); EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET), SyscallFailsWithErrno(EPIPE), SyscallSucceeds())); if (n <= 0) { @@ -234,6 +209,20 @@ TEST_P(SendFileTest, Shutdown) { } } +TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) { + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM)); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + // The value to the count argument has to be so that it is impossible to + // allocate a buffer of this size. In Linux, sendfile transfer at most + // 0x7ffff000 (MAX_RW_COUNT) bytes. + EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004), + SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest, ::testing::Values(AF_UNIX, AF_INET)); diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 1b34e4ef7..d3000dbc6 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -319,17 +319,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { tcpSimpleConnectTest(listener, connector, false); } -TEST_P(SocketInetLoopbackTest, TCPListenClose) { +TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; - constexpr int kAcceptCount = 32; - constexpr int kBacklog = kAcceptCount * 2; - constexpr int kFDs = 128; - constexpr int kThreadCount = 4; - constexpr int kFDsPerThread = kFDs / kThreadCount; + constexpr int kBacklog = 2; + constexpr int kFDs = kBacklog + 1; // Create the listening socket. FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( @@ -348,39 +345,167 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - DisableSave ds; // Too many system calls. 
sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - FileDescriptor clients[kFDs]; - std::unique_ptr<ScopedThread> threads[kThreadCount]; + + // Shutdown the write of the listener, expect to not have any effect. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds()); + for (int i = 0; i < kFDs; i++) { - clients[i] = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); } - for (int i = 0; i < kThreadCount; i++) { - threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr, - &clients, i]() { - for (int j = 0; j < kFDsPerThread; j++) { - int k = i * kFDsPerThread + j; - int ret = - connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } - } - }); + + // Shutdown the read of the listener, expect to fail subsequent + // server accepts, binds and client connects. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EINVAL)); + + // Check that shutdown did not release the port. + FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Check that subsequent connection attempts receive a RST. + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallFailsWithErrno(ECONNREFUSED)); } - for (int i = 0; i < kThreadCount; i++) { - threads[i]->Join(); +} + +TEST_P(SocketInetLoopbackTest, TCPListenClose) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kAcceptCount = 2; + constexpr int kBacklog = kAcceptCount + 2; + constexpr int kFDs = kBacklog * 3; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. 
+ socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + clients.push_back(std::move(client)); } for (int i = 0; i < kAcceptCount; i++) { auto accepted = ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); } - // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked - // before function end. - // ds.reset(); +} + +void TestListenWhileConnect(const TestParam& param, + void (*stopListen)(FileDescriptor&)) { + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kClients = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + clients.push_back(std::move(client)); + } + } + + stopListen(listen_fd); + + for (auto& client : clients) { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = client.get(), + .events = POLLIN, + }; + // When the listening socket is closed, then we expect the remote to reset + // the connection. + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR); + char c; + // Subsequent read can fail with: + // ECONNRESET: If the client connection was established and was reset by the + // remote. + // ECONNREFUSED: If the client connection failed to be established. 
+ ASSERT_THAT(read(client.get(), &c, sizeof(c)), + AnyOf(SyscallFailsWithErrno(ECONNRESET), + SyscallFailsWithErrno(ECONNREFUSED))); + } +} + +TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(close(f.release()), SyscallSucceeds()); + }); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); + }); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -1090,6 +1215,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { if (connects_received >= kConnectAttempts) { // Another thread have shutdown our read side causing the // accept to fail. + ASSERT_EQ(errno, EINVAL); break; } ASSERT_NO_ERRNO(fd); @@ -1157,7 +1283,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -1270,7 +1396,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -2146,8 +2272,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(connect(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 40e673625..d690d9564 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -45,37 +45,31 @@ void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { got_if_infos_ = false; // Get interface list. - std::vector<std::string> if_names; ASSERT_NO_ERRNO(if_helper_.Load()); - if_names = if_helper_.InterfaceList(AF_INET); + std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET); if (if_names.size() != 2) { return; } // Figure out which interface is where. 
- int lo = 0, eth = 1; - if (if_names[lo] != "lo") { - lo = 1; - eth = 0; - } - - if (if_names[lo] != "lo") { - return; - } - - lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo])); - lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]); - if (lo_if_addr_ == nullptr) { + std::string lo = if_names[0]; + std::string eth = if_names[1]; + if (lo != "lo") std::swap(lo, eth); + if (lo != "lo") return; + + lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo)); + auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo); + if (lo_if_addr == nullptr) { return; } - lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr; + lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr); - eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth])); - eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]); - if (eth_if_addr_ == nullptr) { + eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth)); + auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth); + if (eth_if_addr == nullptr) { return; } - eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr; + eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr); got_if_infos_ = true; } @@ -242,7 +236,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the non-receiving socket to the unicast ethernet address. auto norecv_addr = rcv1_addr; reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = - eth_if_sin_addr_; + eth_if_addr_.sin_addr; ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), norecv_addr.addr_len), SyscallSucceedsWithValue(0)); @@ -1028,7 +1022,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; iface.imr_ifindex = lo_if_idx_; - iface.imr_address = eth_if_sin_addr_; + iface.imr_address = eth_if_addr_.sin_addr; ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, sizeof(iface)), SyscallSucceeds()); @@ -1058,7 +1052,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SKIP_IF(IsRunningOnGvisor()); // Verify the received source address. - EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr); + EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr); } // Check that when we are bound to one interface we can set IP_MULTICAST_IF to @@ -1075,7 +1069,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Create sender and bind to eth interface. 
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)), + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&eth_if_addr_), + sizeof(eth_if_addr_)), SyscallSucceeds()); // Run through all possible combinations of index and address for @@ -1085,9 +1080,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, struct in_addr imr_address; } test_data[] = { {lo_if_idx_, {}}, - {0, lo_if_sin_addr_}, - {lo_if_idx_, lo_if_sin_addr_}, - {lo_if_idx_, eth_if_sin_addr_}, + {0, lo_if_addr_.sin_addr}, + {lo_if_idx_, lo_if_addr_.sin_addr}, + {lo_if_idx_, eth_if_addr_.sin_addr}, }; for (auto t : test_data) { ip_mreqn iface = {}; diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h index bec2e96ee..10b90b1e0 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h @@ -36,10 +36,8 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest { // Interface infos. int lo_if_idx_; int eth_if_idx_; - sockaddr* lo_if_addr_; - sockaddr* eth_if_addr_; - in_addr lo_if_sin_addr_; - in_addr eth_if_sin_addr_; + sockaddr_in lo_if_addr_; + sockaddr_in eth_if_addr_; }; } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 2efb96bc3..fbe61c5a0 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -26,7 +26,7 @@ #include "gtest/gtest.h" #include "absl/strings/str_format.h" -#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/capability_util.h" @@ -118,24 +118,6 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { // TODO(mpratt): Check ifinfomsg contents and following attrs. } -PosixError DumpLinks( - const FileDescriptor& fd, uint32_t seq, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = seq; - req.ifm.ifi_family = AF_UNSPEC; - - return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); -} - TEST(NetlinkRouteTest, GetLinkDump) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -161,37 +143,6 @@ TEST(NetlinkRouteTest, GetLinkDump) { EXPECT_TRUE(loopbackFound); } -struct Link { - int index; - std::string name; -}; - -PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - absl::optional<Link> link; - RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { - if (hdr->nlmsg_type != RTM_NEWLINK || - hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { - return; - } - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - if (msg->ifi_type == ARPHRD_LOOPBACK) { - const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); - if (rta == nullptr) { - // Ignore links that do not have a name.
- return; - } - - link = Link(); - link->index = msg->ifi_index; - link->name = std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); - } - })); - return link; -} - // CheckLinkMsg checks a netlink message against an expected link. void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); @@ -209,9 +160,7 @@ void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { } TEST(NetlinkRouteTest, GetLinkByIndex) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -227,13 +176,13 @@ TEST(NetlinkRouteTest, GetLinkByIndex) { req.hdr.nlmsg_flags = NLM_F_REQUEST; req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; - req.ifm.ifi_index = loopback_link->index; + req.ifm.ifi_index = loopback_link.index; bool found = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); + CheckLinkMsg(hdr, loopback_link); found = true; }, false)); @@ -241,9 +190,7 @@ TEST(NetlinkRouteTest, GetLinkByIndex) { } TEST(NetlinkRouteTest, GetLinkByName) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -262,8 +209,8 @@ TEST(NetlinkRouteTest, GetLinkByName) { req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; req.rtattr.rta_type = IFLA_IFNAME; - req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1); - strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname)); + req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1); + strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname)); req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); @@ -271,7 +218,7 @@ TEST(NetlinkRouteTest, GetLinkByName) { ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); + CheckLinkMsg(hdr, loopback_link); found = true; }, false)); @@ -523,9 +470,7 @@ TEST(NetlinkRouteTest, LookupAll) { TEST(NetlinkRouteTest, AddAddr) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -545,7 +490,7 @@ TEST(NetlinkRouteTest, AddAddr) { req.ifa.ifa_prefixlen = 24; req.ifa.ifa_flags = 0; req.ifa.ifa_scope = 0; - req.ifa.ifa_index = loopback_link->index; + req.ifa.ifa_index = loopback_link.index; req.rtattr.rta_type = IFA_LOCAL; req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); inet_pton(AF_INET, "10.0.0.1", &req.addr); diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index 53eb3b6b2..bde1dbb4d 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -18,7 +18,6 @@ #include <linux/netlink.h> #include <linux/rtnetlink.h> -#include "absl/types/optional.h" #include 
"test/syscalls/linux/socket_netlink_util.h" namespace gvisor { @@ -73,14 +72,14 @@ PosixErrorOr<std::vector<Link>> DumpLinks() { return links; } -PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { +PosixErrorOr<Link> LoopbackLink() { ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); for (const auto& link : links) { if (link.type == ARPHRD_LOOPBACK) { - return absl::optional<Link>(link); + return link; } } - return absl::optional<Link>(); + return PosixError(ENOENT, "loopback link not found"); } PosixError LinkAddLocalAddr(int index, int family, int prefixlen, diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index 2c018e487..149c4a7f6 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -20,7 +20,6 @@ #include <vector> -#include "absl/types/optional.h" #include "test/syscalls/linux/socket_netlink_util.h" namespace gvisor { @@ -37,7 +36,8 @@ PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, PosixErrorOr<std::vector<Link>> DumpLinks(); -PosixErrorOr<absl::optional<Link>> FindLoopbackLink(); +// Returns the loopback link on the system. ENOENT if not found. +PosixErrorOr<Link> LoopbackLink(); // LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. PosixError LinkAddLocalAddr(int index, int family, int prefixlen, diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 5d3a39868..53b678e94 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -364,11 +364,6 @@ CreateTCPConnectAcceptSocketPair(int bound, int connected, int type, } MaybeSave(); // Successful accept. - // FIXME(b/110484944) - if (connect_result == -1) { - absl::SleepFor(absl::Seconds(1)); - } - T extra_addr = {}; LocalhostAddr(&extra_addr, dual_stack); return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr, diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index faa1247f6..f103e2e56 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/resource.h> #include <sys/sendfile.h> diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc index 53ad2dda3..6195b11e1 100644 --- a/test/syscalls/linux/tuntap.cc +++ b/test/syscalls/linux/tuntap.cc @@ -56,14 +56,14 @@ PosixErrorOr<std::set<std::string>> DumpLinkNames() { return names; } -PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) { +PosixErrorOr<Link> GetLinkByName(const std::string& name) { ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); for (const auto& link : links) { if (link.name == name) { - return absl::optional<Link>(link); + return link; } } - return absl::optional<Link>(); + return PosixError(ENOENT, "interface not found"); } struct pihdr { @@ -242,7 +242,7 @@ TEST_F(TuntapTest, InvalidReadWrite) { TEST_F(TuntapTest, WriteToDownDevice) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - // FIXME: gVisor always creates enabled/up'd interfaces. + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. 
SKIP_IF(IsRunningOnGvisor()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); @@ -268,23 +268,21 @@ PosixErrorOr<FileDescriptor> OpenAndAttachTap( return PosixError(errno); } - ASSIGN_OR_RETURN_ERRNO(absl::optional<Link> link, GetLinkByName(dev_name)); - if (!link.has_value()) { - return PosixError(ENOENT, "no link"); - } + ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name)); // Interface setup. struct in_addr addr; inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr); - EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24, - &addr, sizeof(addr))); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr, + sizeof(addr))); if (!IsRunningOnGvisor()) { - // FIXME: gVisor doesn't support setting MAC address on interfaces yet. - RETURN_IF_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA))); + // FIXME(b/110961832): gVisor doesn't support setting MAC address on + // interfaces yet. + RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA))); - // FIXME: gVisor always creates enabled/up'd interfaces. - RETURN_IF_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP)); + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. + RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP)); } return fd; diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index 6218fbce1..ff66a79f4 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <grp.h> +#include <sys/resource.h> #include <sys/types.h> #include <unistd.h> @@ -249,6 +250,17 @@ TEST(UidGidRootTest, Setgroups) { SyscallFailsWithErrno(EFAULT)); } +TEST(UidGidRootTest, Setuid_prlimit) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); + + // Change our UID. + EXPECT_THAT(seteuid(65534), SyscallSucceeds()); + + // Despite the UID change, we should be able to get our own limits. + struct rlimit rl = {}; + ASSERT_THAT(prlimit(0, RLIMIT_NOFILE, NULL, &rl), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc index 3a927a430..22e6d1a85 100644 --- a/test/syscalls/linux/utimes.cc +++ b/test/syscalls/linux/utimes.cc @@ -34,17 +34,10 @@ namespace testing { namespace { -// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the -// application's time domain, so when asserting that times are within a window, -// we expand the window to allow for differences between the time domains. -constexpr absl::Duration kClockSlack = absl::Milliseconds(100); - // TimeBoxed runs fn, setting before and after to (coarse realtime) times // guaranteed* to come before and after fn started and completed, respectively. // // fn may be called more than once if the clock is adjusted. -// -// * See the comment on kClockSlack. gVisor breaks this guarantee. void TimeBoxed(absl::Time* before, absl::Time* after, std::function<void()> const& fn) { do { @@ -69,12 +62,6 @@ void TimeBoxed(absl::Time* before, absl::Time* after, // which could lead to test failures, but that is very unlikely to happen. continue; } - - if (IsRunningOnGvisor()) { - // See comment on kClockSlack. - *before -= kClockSlack; - *after += kClockSlack; - } } while (*after < *before); } @@ -235,10 +222,7 @@ void TestUtimensat(int dirFd, std::string const& path) { EXPECT_GE(mtime3, before); EXPECT_LE(mtime3, after); - if (!IsRunningOnGvisor()) { - // FIXME(b/36516566): Gofers set atime and mtime to different "now" times. 
- EXPECT_EQ(atime3, mtime3); - } + EXPECT_EQ(atime3, mtime3); } TEST(UtimensatTest, OnAbsPath) { diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 9b219cfd6..39b5b2f56 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -31,14 +31,8 @@ namespace gvisor { namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. + +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class WriteTest : public ::testing::Test { public: ssize_t WriteBytes(int fd, int bytes) { diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index 8b00ef44c..3231732ec 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -41,12 +41,12 @@ class XattrTest : public FileTest {}; TEST_F(XattrTest, XattrNonexistentFile) { const char* path = "/does/not/exist"; - EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(getxattr(path, nullptr, nullptr, 0), + const char* name = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT)); EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT)); } TEST_F(XattrTest, XattrNullName) { diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl index 92b0b5fc0..132040c20 100644 --- a/tools/bazeldefs/platforms.bzl +++ b/tools/bazeldefs/platforms.bzl @@ -2,15 +2,10 @@ # Platform to associated tags. platforms = { - "ptrace": [ - # TODO(b/120560048): Make the tests run without this tag. - "no-sandbox", - ], + "ptrace": [], "kvm": [ "manual", "local", - # TODO(b/120560048): Make the tests run without this tag. - "no-sandbox", ], } diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index c5be52ecd..8c9995fd4 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -105,7 +105,6 @@ def _go_template_instance_impl(ctx): executable = ctx.executable._tool, ) - # TODO: How can we get the dependencies out? 
return struct( files = depset([output]), ) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 3437aa476..309ee9c21 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -206,7 +206,7 @@ func main() { initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) } emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name) } emitLoadValue := func(name, typName string) { fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName) diff --git a/tools/image_build.sh b/tools/image_build.sh deleted file mode 100755 index 9b20a740d..000000000 --- a/tools/image_build.sh +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script is responsible for building a new GCP image that: 1) has nested -# virtualization enabled, and 2) has been completely set up with the -# image_setup.sh script. This script should be idempotent, as we memoize the -# setup script with a hash and check for that name. -# -# The GCP project name should be defined via a gcloud config. - -set -xeo pipefail - -# Parameters. -declare -r ZONE=${ZONE:-us-central1-f} -declare -r USERNAME=${USERNAME:-test} -declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud} -declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts} - -# Random names. -declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z) -declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z) -declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) - -# Hashes inputs. -declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@") -declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16) -declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH} - -# Does the image already exist? Skip the build. -declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") -if ! [[ -z "${existing}" ]]; then - echo "${existing}" - exit 0 -fi - -# Set the zone for all actions. -gcloud config set compute/zone "${ZONE}" - -# Start a unique instance. Note that this instance will have a unique persistent -# disk as it's boot disk with the same name as the instance. -gcloud compute instances create \ - --quiet \ - --image-project "${IMAGE_PROJECT}" \ - --image-family "${IMAGE_FAMILY}" \ - --boot-disk-size "200GB" \ - "${INSTANCE_NAME}" -function cleanup { - gcloud compute instances delete --quiet "${INSTANCE_NAME}" -} -trap cleanup EXIT - -# Wait for the instance to become available. 
-declare attempts=0 -while [[ "${attempts}" -lt 30 ]]; do - attempts=$((${attempts}+1)) - if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then - break - fi -done -if [[ "${attempts}" -ge 30 ]]; then - echo "too many attempts: failed" - exit 1 -fi - -# Run the install scripts provided. -for arg; do - gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" -done - -# Stop the instance; required before creating an image. -gcloud compute instances stop --quiet "${INSTANCE_NAME}" - -# Create a snapshot of the instance disk. -gcloud compute disks snapshot \ - --quiet \ - --zone="${ZONE}" \ - --snapshot-names="${SNAPSHOT_NAME}" \ - "${INSTANCE_NAME}" - -# Create the disk image. -gcloud compute images create \ - --quiet \ - --source-snapshot="${SNAPSHOT_NAME}" \ - --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ - "${IMAGE_NAME}" diff --git a/tools/images/BUILD b/tools/images/BUILD index 66ffd02aa..8d319e3e4 100644 --- a/tools/images/BUILD +++ b/tools/images/BUILD @@ -6,14 +6,9 @@ package( licenses = ["notice"], ) -genrule( +sh_binary( name = "zone", - outs = ["zone.txt"], - cmd = "gcloud config get-value compute/zone > \"$@\"", - tags = [ - "local", - "manual", - ], + srcs = ["zone.sh"], ) sh_binary( diff --git a/tools/images/README.md b/tools/images/README.md new file mode 100644 index 000000000..26c0f84f2 --- /dev/null +++ b/tools/images/README.md @@ -0,0 +1,42 @@ +# Images + +All commands in this directory require the `gcloud` project to be set. + +For example: `gcloud config set project gvisor-kokoro-testing`. + +Images can be generated by using the `vm_image` rule. This rule will generate a +binary target that builds an image in an idempotent way, and can be referenced +from other rules. + +For example: + +``` +vm_image( + name = "ubuntu", + project = "ubuntu-1604-lts", + family = "ubuntu-os-cloud", + scripts = [ + "script.sh", + "other.sh", + ], +) +``` + +These images can be built manually by executing the target. The output on +`stdout` will be the image id (in the current project). + +Images are always named per the hash of all the hermetic input scripts. This +allows images to be memoized quickly and easily. + +The `vm_test` rule can be used to execute a command remotely. This is still +under development however, and will likely change over time. + +For example: + +``` +vm_test( + name = "mycommand", + image = ":ubuntu", + targets = [":test"], +) +``` diff --git a/tools/images/build.sh b/tools/images/build.sh index f89f39cbd..f39f723b8 100755 --- a/tools/images/build.sh +++ b/tools/images/build.sh @@ -19,7 +19,7 @@ # image_setup.sh script. This script should be idempotent, as we memoize the # setup script with a hash and check for that name. -set -xeou pipefail +set -eou pipefail # Parameters. declare -r USERNAME=${USERNAME:-test} @@ -34,10 +34,10 @@ declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) # Hash inputs in order to memoize the produced image. declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16) -declare -r IMAGE_NAME=${IMAGE_FAMILY:-image-}${SETUP_HASH} +declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH} # Does the image already exist? Skip the build. -declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") +declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") if ! 
[[ -z "${existing}" ]]; then echo "${existing}" exit 0 @@ -48,28 +48,30 @@ export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin} # Start a unique instance. Note that this instance will have a unique persistent # disk as it's boot disk with the same name as the instance. -gcloud compute instances create \ +(set -x; gcloud compute instances create \ --quiet \ --image-project "${IMAGE_PROJECT}" \ --image-family "${IMAGE_FAMILY}" \ --boot-disk-size "200GB" \ --zone "${ZONE}" \ - "${INSTANCE_NAME}" >/dev/null + "${INSTANCE_NAME}" >/dev/null) function cleanup { - gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}" + (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}") } trap cleanup EXIT # Wait for the instance to become available (up to 5 minutes). +echo -n "Waiting for ${INSTANCE_NAME}" declare timeout=300 declare success=0 declare internal="" declare -r start=$(date +%s) declare -r end=$((${start}+${timeout})) while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do - if gcloud compute ssh --zone "${internal}" "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then + echo -n "." + if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then success=$((${success}+1)) - elif gcloud compute ssh --zone --internal-ip "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then + elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then success=$((${success}+1)) internal="--internal-ip" fi @@ -78,29 +80,34 @@ done if [[ "${success}" -eq "0" ]]; then echo "connect timed out after ${timeout} seconds." exit 1 +else + echo "done." fi # Run the install scripts provided. for arg; do - gcloud compute ssh --zone "${internal}" "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" >/dev/null + (set -x; gcloud compute ssh ${internal} \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + sudo bash - <"${arg}" >/dev/null) done # Stop the instance; required before creating an image. -gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null +(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null) # Create a snapshot of the instance disk. -gcloud compute disks snapshot \ +(set -x; gcloud compute disks snapshot \ --quiet \ --zone "${ZONE}" \ --snapshot-names="${SNAPSHOT_NAME}" \ - "${INSTANCE_NAME}" >/dev/null + "${INSTANCE_NAME}" >/dev/null) # Create the disk image. -gcloud compute images create \ +(set -x; gcloud compute images create \ --quiet \ --source-snapshot="${SNAPSHOT_NAME}" \ --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ - "${IMAGE_NAME}" >/dev/null + "${IMAGE_NAME}" >/dev/null) # Finish up. echo "${IMAGE_NAME}" diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl index de365d153..2847e1847 100644 --- a/tools/images/defs.bzl +++ b/tools/images/defs.bzl @@ -1,76 +1,49 @@ -"""Image configuration. - -Images can be generated by using the vm_image rule. For example, - - vm_image( - name = "ubuntu", - project = "...", - family = "...", - scripts = [ - "script.sh", - "other.sh", - ], - ) - -This will always create an vm_image in the current default gcloud project. The -rule has a text file as its output containing the image name. This will enforce -serialization for all dependent rules. - -Images are always named per the hash of all the hermetic input scripts. 
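To make the hash-based naming mentioned here concrete, below is a minimal sketch of the memoization logic as it appears in tools/images/build.sh in this change. Variable names, defaults, and commands mirror the script itself; treat it as an illustration of the scheme rather than the full builder.

```
#!/bin/bash
# Sketch of the memoization used by tools/images/build.sh: the image name is
# derived from a hash of every hermetic input, so re-running the builder with
# identical scripts finds the existing image and exits early.
set -eou pipefail

# Parameters (overridable via the environment, as in build.sh).
declare -r USERNAME=${USERNAME:-test}
declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}

# Hash the username, image project/family and the contents of all setup
# scripts passed as arguments; keep a 16-character prefix of the digest.
declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH}

# If an image with this name already exists, emit it and skip the build.
declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
if ! [[ -z "${existing}" ]]; then
  echo "${existing}"
  exit 0
fi

# ...otherwise build.sh goes on to create the instance, run the setup
# scripts, snapshot the boot disk, and create the image under ${IMAGE_NAME}.
```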
This -allows images to be memoized quickly and easily. - -The vm_test rule can be used to execute a command remotely. For example, - - vm_test( - name = "mycommand", - image = ":myimage", - targets = [":test"], - ) -""" +"""Image configuration. See README.md.""" load("//tools:defs.bzl", "default_installer") -def _vm_image_impl(ctx): +# vm_image_builder is a rule that will construct a shell script that actually +# generates a given VM image. Note that this does not _run_ the shell script +# (although it can be run manually). It will be run manually during generation +# of the vm_image target itself. This level of indirection is used so that the +# build system itself only runs the builder once when multiple targets depend +# on it, avoiding a set of races and conflicts. +def _vm_image_builder_impl(ctx): + # Generate a binary that actually builds the image. + builder = ctx.actions.declare_file(ctx.label.name) script_paths = [] for script in ctx.files.scripts: script_paths.append(script.short_path) + builder_content = "\n".join([ + "#!/bin/bash", + "export ZONE=$(%s)" % ctx.files.zone[0].short_path, + "export USERNAME=%s" % ctx.attr.username, + "export IMAGE_PROJECT=%s" % ctx.attr.project, + "export IMAGE_FAMILY=%s" % ctx.attr.family, + "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)), + "", + ]) + ctx.actions.write(builder, builder_content, is_executable = True) - resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( - command = "USERNAME=%s ZONE=$(cat %s) IMAGE_PROJECT=%s IMAGE_FAMILY=%s %s %s > %s" % - ( - ctx.attr.username, - ctx.files.zone[0].path, - ctx.attr.project, - ctx.attr.family, - ctx.executable.builder.path, - " ".join(script_paths), - ctx.outputs.out.path, - ), - tools = [ctx.attr.builder] + ctx.attr.scripts, - ) - - ctx.actions.run_shell( - tools = resolved_inputs, - outputs = [ctx.outputs.out], - progress_message = "Building image...", - execution_requirements = {"local": "true"}, - command = argv, - input_manifests = runfiles_manifests, - ) + # Note that the scripts should only be files, and should not include any + # indirect transitive dependencies. The build script wouldn't work. return [DefaultInfo( - files = depset([ctx.outputs.out]), - runfiles = ctx.runfiles(files = [ctx.outputs.out]), + executable = builder, + runfiles = ctx.runfiles( + files = ctx.files.scripts + ctx.files._builder + ctx.files.zone, + ), )] -_vm_image = rule( +vm_image_builder = rule( attrs = { - "builder": attr.label( + "_builder": attr.label( executable = True, default = "//tools/images:builder", cfg = "host", ), "username": attr.string(default = "$(whoami)"), "zone": attr.label( + executable = True, default = "//tools/images:zone", cfg = "host", ), @@ -78,20 +51,55 @@ _vm_image = rule( "project": attr.string(mandatory = True), "scripts": attr.label_list(allow_files = True), }, - outputs = { - "out": "%{name}.txt", + executable = True, + implementation = _vm_image_builder_impl, +) + +# See vm_image_builder above. +def _vm_image_impl(ctx): + # Run the builder to generate our output. 
+ echo = ctx.actions.declare_file(ctx.label.name) + resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( + command = "echo -ne \"#!/bin/bash\\necho $(%s)\\n\" > %s && chmod 0755 %s" % ( + ctx.files.builder[0].path, + echo.path, + echo.path, + ), + tools = [ctx.attr.builder], + ) + ctx.actions.run_shell( + tools = resolved_inputs, + outputs = [echo], + progress_message = "Building image...", + execution_requirements = {"local": "true"}, + command = argv, + input_manifests = runfiles_manifests, + ) + + # Return just the echo command. All of the builder runfiles have been + # resolved and consumed in the generation of the trivial echo script. + return [DefaultInfo(executable = echo)] + +_vm_image = rule( + attrs = { + "builder": attr.label( + executable = True, + cfg = "host", + ), }, + executable = True, implementation = _vm_image_impl, ) -def vm_image(**kwargs): - _vm_image( - tags = [ - "local", - "manual", - ], +def vm_image(name, **kwargs): + vm_image_builder( + name = name + "_builder", **kwargs ) + _vm_image( + name = name, + builder = ":" + name + "_builder", + ) def _vm_test_impl(ctx): runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) diff --git a/tools/images/zone.sh b/tools/images/zone.sh new file mode 100755 index 000000000..79569fb19 --- /dev/null +++ b/tools/images/zone.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +exec gcloud config get-value compute/zone diff --git a/tools/nogo.json b/tools/nogo.json index 83cb76b93..ae969409e 100644 --- a/tools/nogo.json +++ b/tools/nogo.json @@ -9,27 +9,6 @@ "/external/": "allowed: not subject to unsafe naming rules" } }, - "copylocks": { - "exclude_files": { - ".*_state_autogen.go": "fix: m.Failf copies by value", - "/pkg/log/json.go": "fix: Emit passes lock by value: gvisor.dev/gvisor/pkg/log.JSONEmitter contains gvisor.dev/gvisor/pkg/log.Writer contains gvisor.dev/gvisor/pkg/sync.Mutex", - "/pkg/log/log_test.go": "fix: call of fmt.Printf copies lock value: gvisor.dev/gvisor/pkg/log.Writer contains gvisor.dev/gvisor/pkg/sync.Mutex", - "/pkg/sentry/fs/host/socket_test.go": "fix: call of t.Errorf copies lock value: gvisor.dev/gvisor/pkg/sentry/fs/host.ConnectedEndpoint contains gvisor.dev/gvisor/pkg/refs.AtomicRefCount contains gvisor.dev/gvisor/pkg/sync.Mutex", - "/pkg/sentry/fs/proc/sys_net.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/proc.tcpMemInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex", - "/pkg/sentry/fs/proc/sys_net.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/proc.tcpSack contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex", - "/pkg/sentry/fs/tty/slave.go": "fix: Truncate passes lock by value: gvisor.dev/gvisor/pkg/sentry/fs/tty.slaveInodeOperations contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.SimpleFileInode contains gvisor.dev/gvisor/pkg/sentry/fs/fsutil.InodeSimpleAttributes contains gvisor.dev/gvisor/pkg/sync.RWMutex", - "/pkg/sentry/kernel/time/time.go": "fix: Readiness passes lock by value: gvisor.dev/gvisor/pkg/sentry/kernel/time.ClockEventsQueue contains gvisor.dev/gvisor/pkg/waiter.Queue contains gvisor.dev/gvisor/pkg/sync.RWMutex", - "/pkg/sentry/kernel/syscalls_state.go": "fix: assignment copies lock value to *s: gvisor.dev/gvisor/pkg/sentry/kernel.SyscallTable contains gvisor.dev/gvisor/pkg/sentry/kernel.SyscallFlagsTable contains gvisor.dev/gvisor/pkg/sync.Mutex" - } - }, - "lostcancel": { - "exclude_files": { - "/pkg/tcpip/network/arp/arp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak", - "/pkg/tcpip/stack/ndp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak", - "/pkg/tcpip/transport/udp/udp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak", - "/pkg/tcpip/transport/tcp/testing/context/context.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak" - } - }, "nilness": { "exclude_files": { "/com_github_vishvananda_netlink/route_linux.go": "allowed: false positive", @@ -40,37 +19,6 @@ "/external/io_opencensus_go/tag/map_codec.go": "allowed: false positive" } }, - "printf": { - "exclude_files": { - ".*_abi_autogen_test.go": "fix: Sprintf format has insufficient args", - "/pkg/segment/test/segment_test.go": "fix: Errorf format %d arg seg.Start is a func value, not called", - "/pkg/tcpip/tcpip_test.go": "fix: Error call has possible formatting directive %q", - "/pkg/tcpip/header/eth_test.go": "fix: Fatalf format %s reads arg #3, but 
call has 2 args", - "/pkg/tcpip/header/ndp_test.go": "fix: Errorf format %d reads arg #1, but call has 0 args", - "/pkg/eventchannel/event_test.go": "fix: Fatal call has possible formatting directive %v", - "/pkg/tcpip/stack/ndp.go": "fix: Fatalf format %s has arg protocolAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %s has arg flags of wrong type gvisor.dev/gvisor/pkg/sentry/fs.FileFlags", - "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %d arg f.FD is a func value, not called", - "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call", - "/pkg/tcpip/transport/udp/udp_test.go": "fix: Fatalf format %s has arg h.srcAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.FullAddress", - "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Fatalf format %s has arg tcpTW of wrong type gvisor.dev/gvisor/pkg/tcpip.TCPTimeWaitTimeoutOption", - "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Errorf call needs 1 arg but has 2 args", - "/pkg/tcpip/stack/ndp_test.go": "fix: Errorf format %s reads arg #3, but call has 2 args", - "/pkg/tcpip/stack/ndp_test.go": "fix: Fatalf format %s reads arg #5, but call has 4 args", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg protoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic1ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic2ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatal call has possible formatting directive %t", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf call has arguments but no formatting directives", - "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call", - "/pkg/sentry/fsimpl/tmpfs/stat_test.go": "fix: Errorf format %v reads arg #1, but call has 0 args", - "/runsc/container/test_app/test_app.go": "fix: Fatal call has possible formatting directive %q", - "/test/root/cgroup_test.go": "fix: Errorf format %s has arg gots of wrong type []int", - "/test/root/cgroup_test.go": "fix: Fatalf format %v reads arg #3, but call has 2 args", - "/test/runtimes/runner.go": "fix: Skip call has possible formatting directive %q", - "/test/runtimes/blacklist_test.go": "fix: Errorf format %q has arg blacklistFile of wrong type *string" - } - }, "structtag": { "exclude_files": { "/external/": "allowed: may use arbitrary tags" @@ -83,16 +31,9 @@ "/pkg/gohacks/gohacks_unsafe.go": "allowed: special case", "/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go": "allowed: special case", "/pkg/sentry/platform/kvm/(bluepill|machine)_unsafe.go": "allowed: special case", - "/pkg/sentry/platform/kvm/machine_arm64_unsafe.go": "fix: gvisor.dev/issue/22464", "/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go": "allowed: special case", "/pkg/sentry/platform/safecopy/safecopy_unsafe.go": "allowed: special case", "/pkg/sentry/vfs/mount_unsafe.go": "allowed: special case" } - }, - "unusedresult": { - "exclude_files": { - "/pkg/sentry/fsimpl/proc/task_net.go": "fix: result of fmt.Sprintf call not used", - "/pkg/sentry/fsimpl/proc/tasks_net.go": "fix: result of fmt.Sprintf call not used" - } } } |
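As a closing illustration of the tools/images changes above: because the zone target is now an sh_binary (rather than a genrule writing zone.txt) and vm_image targets are executable, zone and image values are obtained by running the targets. This is a hypothetical usage sketch; the `ubuntu` target name comes from the README example in this change and the project name is the README's example as well, not targets or settings introduced by this patch.

```
# The gcloud project must be configured first (see tools/images/README.md).
gcloud config set project gvisor-kokoro-testing

# The zone helper is an sh_binary that prints the configured compute zone.
bazel run //tools/images:zone

# Executing an image target (e.g. the README's vm_image(name = "ubuntu"))
# builds the image if needed and prints the memoized image name on stdout.
bazel run //tools/images:ubuntu
```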