diff options
Diffstat (limited to 'pkg')
168 files changed, 5090 insertions, 1727 deletions
diff --git a/pkg/context/context.go b/pkg/context/context.go index 23e009ef3..5319b6d8d 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -127,10 +127,6 @@ func (logContext) Value(key interface{}) interface{} { var bgContext = &logContext{Logger: log.Log()} // Background returns an empty context using the default logger. -// -// Users should be wary of using a Background context. Please tag any use with -// FIXME(b/38173783) and a note to remove this use. -// // Generally, one should use the Task as their context when available, or avoid // having to use a context in places where a Task is unavailable. // diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 7f41b4a27..43750360b 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -78,7 +78,7 @@ func TestMultiEmitter(t *testing.T) { for _, name := range names { m := testMessage{name: name} if _, err := me.Emit(m); err != nil { - t.Fatal("me.Emit(%v) failed: %v", m, err) + t.Fatalf("me.Emit(%v) failed: %v", m, err) } } @@ -96,7 +96,7 @@ func TestMultiEmitter(t *testing.T) { // Close multiEmitter. if err := me.Close(); err != nil { - t.Fatal("me.Close() failed: %v", err) + t.Fatalf("me.Close() failed: %v", err) } // All testEmitters should be closed. diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 8f93e4d6d..0d07da3b1 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -86,12 +86,21 @@ func (l *List) Back() Element { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *List) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *List) PushFront(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { ElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -106,7 +115,6 @@ func (l *List) PushBack(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { ElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -127,7 +135,6 @@ func (l *List) PushBackList(m *List) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/log/glog.go b/pkg/log/glog.go index b4f7bb5a4..f57c4427b 100644 --- a/pkg/log/glog.go +++ b/pkg/log/glog.go @@ -25,7 +25,7 @@ import ( // GoogleEmitter is a wrapper that emits logs in a format compatible with // package github.com/golang/glog. type GoogleEmitter struct { - Writer + *Writer } // pid is used for the threadid component of the header. @@ -46,7 +46,7 @@ var pid = os.Getpid() // line The line number // msg The user-supplied message // -func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) { +func (g GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) { // Log level. prefix := byte('?') switch level { @@ -81,5 +81,5 @@ func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format message := fmt.Sprintf(format, args...) // Emit the formatted result. - fmt.Fprintf(&g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message) + fmt.Fprintf(g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message) } diff --git a/pkg/log/json.go b/pkg/log/json.go index 0943db1cc..bdf9d691e 100644 --- a/pkg/log/json.go +++ b/pkg/log/json.go @@ -58,7 +58,7 @@ func (lv *Level) UnmarshalJSON(b []byte) error { // JSONEmitter logs messages in json format. type JSONEmitter struct { - Writer + *Writer } // Emit implements Emitter.Emit. diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go index 6c6fc8b6f..5883e95e1 100644 --- a/pkg/log/json_k8s.go +++ b/pkg/log/json_k8s.go @@ -29,11 +29,11 @@ type k8sJSONLog struct { // K8sJSONEmitter logs messages in json format that is compatible with // Kubernetes fluent configuration. type K8sJSONEmitter struct { - Writer + *Writer } // Emit implements Emitter.Emit. -func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { +func (e K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) { j := k8sJSONLog{ Log: fmt.Sprintf(format, v...), Level: level, diff --git a/pkg/log/log.go b/pkg/log/log.go index a794da1aa..37e0605ad 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -374,5 +374,5 @@ func CopyStandardLogTo(l Level) error { func init() { // Store the initial value for the log. - log.Store(&BasicLogger{Level: Info, Emitter: &GoogleEmitter{Writer{Next: os.Stderr}}}) + log.Store(&BasicLogger{Level: Info, Emitter: GoogleEmitter{&Writer{Next: os.Stderr}}}) } diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go index 402cc29ae..9ff18559b 100644 --- a/pkg/log/log_test.go +++ b/pkg/log/log_test.go @@ -52,7 +52,7 @@ func TestDropMessages(t *testing.T) { t.Fatalf("Write should have failed") } - fmt.Printf("writer: %+v\n", w) + fmt.Printf("writer: %#v\n", &w) tw.fail = false if _, err := w.Write([]byte("line 2\n")); err != nil { @@ -76,7 +76,7 @@ func TestDropMessages(t *testing.T) { func TestCaller(t *testing.T) { tw := &testWriter{} - e := &GoogleEmitter{Writer: Writer{Next: tw}} + e := GoogleEmitter{Writer: &Writer{Next: tw}} bl := &BasicLogger{ Emitter: e, Level: Debug, @@ -94,7 +94,7 @@ func BenchmarkGoogleLogging(b *testing.B) { tw := &testWriter{ limit: 1, // Only record one message. } - e := &GoogleEmitter{Writer: Writer{Next: tw}} + e := GoogleEmitter{Writer: &Writer{Next: tw}} bl := &BasicLogger{ Emitter: e, Level: Debug, diff --git a/pkg/p9/client.go b/pkg/p9/client.go index a6f493b82..71e944c30 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -174,7 +174,7 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client // our sendRecv function to use that functionality. Otherwise, // we stick to sendRecvLegacy. rversion := Rversion{} - err := c.sendRecvLegacy(&Tversion{ + _, err := c.sendRecvLegacy(&Tversion{ Version: versionString(requested), MSize: messageSize, }, &rversion) @@ -219,11 +219,11 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.sendRecv = c.sendRecvChannel } else { // Channel setup failed; fallback. - c.sendRecv = c.sendRecvLegacy + c.sendRecv = c.sendRecvLegacySyscallErr } } else { // No channels available: use the legacy mechanism. - c.sendRecv = c.sendRecvLegacy + c.sendRecv = c.sendRecvLegacySyscallErr } // Ensure that the socket and channels are closed when the socket is shut @@ -305,7 +305,7 @@ func (c *Client) openChannel(id int) error { ) // Open the data channel. - if err := c.sendRecvLegacy(&Tchannel{ + if _, err := c.sendRecvLegacy(&Tchannel{ ID: uint32(id), Control: 0, }, &rchannel0); err != nil { @@ -319,7 +319,7 @@ func (c *Client) openChannel(id int) error { defer rchannel0.FilePayload().Close() // Open the channel for file descriptors. - if err := c.sendRecvLegacy(&Tchannel{ + if _, err := c.sendRecvLegacy(&Tchannel{ ID: uint32(id), Control: 1, }, &rchannel1); err != nil { @@ -431,13 +431,28 @@ func (c *Client) waitAndRecv(done chan error) error { } } +// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all +// non-syscall errors to EIO. +func (c *Client) sendRecvLegacySyscallErr(t message, r message) error { + received, err := c.sendRecvLegacy(t, r) + if !received { + log.Warningf("p9.Client.sendRecvChannel: %v", err) + return syscall.EIO + } + return err +} + // sendRecvLegacy performs a roundtrip message exchange. // +// sendRecvLegacy returns true if a message was received. This allows us to +// differentiate between failed receives and successful receives where the +// response was an error message. +// // This is called by internal functions. -func (c *Client) sendRecvLegacy(t message, r message) error { +func (c *Client) sendRecvLegacy(t message, r message) (bool, error) { tag, ok := c.tagPool.Get() if !ok { - return ErrOutOfTags + return false, ErrOutOfTags } defer c.tagPool.Put(tag) @@ -457,12 +472,12 @@ func (c *Client) sendRecvLegacy(t message, r message) error { err := send(c.socket, Tag(tag), t) c.sendMu.Unlock() if err != nil { - return err + return false, err } // Co-ordinate with other receivers. if err := c.waitAndRecv(resp.done); err != nil { - return err + return false, err } // Is it an error message? @@ -470,14 +485,14 @@ func (c *Client) sendRecvLegacy(t message, r message) error { // For convenience, we transform these directly // into errors. Handlers need not handle this case. if rlerr, ok := resp.r.(*Rlerror); ok { - return syscall.Errno(rlerr.Error) + return true, syscall.Errno(rlerr.Error) } // At this point, we know it matches. // // Per recv call above, we will only allow a type // match (and give our r) or an instance of Rlerror. - return nil + return true, nil } // sendRecvChannel uses channels to send a message. @@ -486,7 +501,7 @@ func (c *Client) sendRecvChannel(t message, r message) error { c.channelsMu.Lock() if len(c.availableChannels) == 0 { c.channelsMu.Unlock() - return c.sendRecvLegacy(t, r) + return c.sendRecvLegacySyscallErr(t, r) } idx := len(c.availableChannels) - 1 ch := c.availableChannels[idx] @@ -526,7 +541,11 @@ func (c *Client) sendRecvChannel(t message, r message) error { } // Parse the server's response. - _, retErr := ch.recv(r, rsz) + resp, retErr := ch.recv(r, rsz) + if resp == nil { + log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr) + retErr = syscall.EIO + } // Release the channel. c.channelsMu.Lock() diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index 29a0afadf..c757583e0 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -96,7 +96,12 @@ func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) e } func BenchmarkSendRecvLegacy(b *testing.B) { - benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy }) + benchmarkSendRecv(b, func(c *Client) func(message, message) error { + return func(t message, r message) error { + _, err := c.sendRecvLegacy(t, r) + return err + } + }) } func BenchmarkSendRecvChannel(b *testing.B) { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index d4ffbc8e3..cab35896f 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -97,12 +97,12 @@ type File interface { // free to ignore the hint entirely (i.e. the value returned may be larger // than size). All size checking is done independently at the syscall layer. // - // TODO(b/127675828): Determine concurrency guarantees once implemented. + // On the server, GetXattr has a read concurrency guarantee. GetXattr(name string, size uint64) (string, error) // SetXattr sets extended attributes on this node. // - // TODO(b/127675828): Determine concurrency guarantees once implemented. + // On the server, SetXattr has a write concurrency guarantee. SetXattr(name, value string, flags uint32) error // ListXattr lists the names of the extended attributes on this node. @@ -113,12 +113,12 @@ type File interface { // free to ignore the hint entirely (i.e. the value returned may be larger // than size). All size checking is done independently at the syscall layer. // - // TODO(b/148303075): Determine concurrency guarantees once implemented. + // On the server, ListXattr has a read concurrency guarantee. ListXattr(size uint64) (map[string]struct{}, error) // RemoveXattr removes extended attributes on this node. // - // TODO(b/148303075): Determine concurrency guarantees once implemented. + // On the server, RemoveXattr has a write concurrency guarantee. RemoveXattr(name string) error // Allocate allows the caller to directly manipulate the allocated disk space diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 2ac45eb80..1db5797dd 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -48,6 +48,8 @@ func ExtractErrno(err error) syscall.Errno { return ExtractErrno(e.Err) case *os.SyscallError: return ExtractErrno(e.Err) + case *os.LinkError: + return ExtractErrno(e.Err) } // Default case. @@ -920,8 +922,15 @@ func (t *Tgetxattr) handle(cs *connState) message { } defer ref.DecRef() - val, err := ref.file.GetXattr(t.Name, t.Size) - if err != nil { + var val string + if err := ref.safelyRead(func() (err error) { + // Don't allow getxattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + val, err = ref.file.GetXattr(t.Name, t.Size) + return err + }); err != nil { return newErr(err) } return &Rgetxattr{Value: val} @@ -935,7 +944,13 @@ func (t *Tsetxattr) handle(cs *connState) message { } defer ref.DecRef() - if err := ref.file.SetXattr(t.Name, t.Value, t.Flags); err != nil { + if err := ref.safelyWrite(func() error { + // Don't allow setxattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + return ref.file.SetXattr(t.Name, t.Value, t.Flags) + }); err != nil { return newErr(err) } return &Rsetxattr{} @@ -949,10 +964,18 @@ func (t *Tlistxattr) handle(cs *connState) message { } defer ref.DecRef() - xattrs, err := ref.file.ListXattr(t.Size) - if err != nil { + var xattrs map[string]struct{} + if err := ref.safelyRead(func() (err error) { + // Don't allow listxattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + xattrs, err = ref.file.ListXattr(t.Size) + return err + }); err != nil { return newErr(err) } + xattrList := make([]string, 0, len(xattrs)) for x := range xattrs { xattrList = append(xattrList, x) @@ -968,7 +991,13 @@ func (t *Tremovexattr) handle(cs *connState) message { } defer ref.DecRef() - if err := ref.file.RemoveXattr(t.Name); err != nil { + if err := ref.safelyWrite(func() error { + // Don't allow removexattr on files that have been deleted. + if ref.isDeleted() { + return syscall.EINVAL + } + return ref.file.RemoveXattr(t.Name) + }); err != nil { return newErr(err) } return &Rremovexattr{} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 3863ad1f5..57b89ad7d 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -1926,19 +1926,17 @@ func (r *Rreaddir) decode(b *buffer) { // encode implements encoder.encode. func (r *Rreaddir) encode(b *buffer) { entriesBuf := buffer{} + payloadSize := 0 for _, d := range r.Entries { d.encode(&entriesBuf) - if len(entriesBuf.data) >= int(r.Count) { + if len(entriesBuf.data) > int(r.Count) { break } + payloadSize = len(entriesBuf.data) } - if len(entriesBuf.data) < int(r.Count) { - r.Count = uint32(len(entriesBuf.data)) - r.payload = entriesBuf.data - } else { - r.payload = entriesBuf.data[:r.Count] - } - b.Write32(uint32(r.Count)) + r.Count = uint32(payloadSize) + r.payload = entriesBuf.data[:payloadSize] + b.Write32(r.Count) } // Type implements message.Type. diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go index c20324404..7facc9f5e 100644 --- a/pkg/p9/messages_test.go +++ b/pkg/p9/messages_test.go @@ -216,7 +216,7 @@ func TestEncodeDecode(t *testing.T) { }, &Rreaddir{ // Count must be sufficient to encode a dirent. - Count: 0x18, + Count: 0x1a, Entries: []Dirent{{QID: QID{Type: 2}}}, }, &Tfsync{ diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index a0d274f3b..38038abdf 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -236,7 +236,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) { // Convert errors appropriately; see above. if rlerr, ok := r.(*Rlerror); ok { - return nil, syscall.Errno(rlerr.Error) + return r, syscall.Errno(rlerr.Error) } return r, nil diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s index 129691d68..00b46c18f 100644 --- a/pkg/safecopy/memcpy_amd64.s +++ b/pkg/safecopy/memcpy_amd64.s @@ -55,15 +55,9 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36 MOVQ from+8(FP), SI MOVQ n+16(FP), BX - // REP instructions have a high startup cost, so we handle small sizes - // with some straightline code. The REP MOVSQ instruction is really fast - // for large sizes. The cutover is approximately 2K. tail: - // move_129through256 or smaller work whether or not the source and the - // destination memory regions overlap because they load all data into - // registers before writing it back. move_256through2048 on the other - // hand can be used only when the memory regions don't overlap or the copy - // direction is forward. + // BSR+branch table make almost all memmove/memclr benchmarks worse. Not + // worth doing. TESTQ BX, BX JEQ move_0 CMPQ BX, $2 @@ -83,31 +77,45 @@ tail: JBE move_65through128 CMPQ BX, $256 JBE move_129through256 - // TODO: use branch table and BSR to make this just a single dispatch -/* - * forward copy loop - */ - CMPQ BX, $2048 - JLS move_256through2048 - - // Check alignment - MOVL SI, AX - ORL DI, AX - TESTL $7, AX - JEQ fwdBy8 - - // Do 1 byte at a time - MOVQ BX, CX - REP; MOVSB - RET - -fwdBy8: - // Do 8 bytes at a time - MOVQ BX, CX - SHRQ $3, CX - ANDQ $7, BX - REP; MOVSQ +move_257plus: + SUBQ $256, BX + MOVOU (SI), X0 + MOVOU X0, (DI) + MOVOU 16(SI), X1 + MOVOU X1, 16(DI) + MOVOU 32(SI), X2 + MOVOU X2, 32(DI) + MOVOU 48(SI), X3 + MOVOU X3, 48(DI) + MOVOU 64(SI), X4 + MOVOU X4, 64(DI) + MOVOU 80(SI), X5 + MOVOU X5, 80(DI) + MOVOU 96(SI), X6 + MOVOU X6, 96(DI) + MOVOU 112(SI), X7 + MOVOU X7, 112(DI) + MOVOU 128(SI), X8 + MOVOU X8, 128(DI) + MOVOU 144(SI), X9 + MOVOU X9, 144(DI) + MOVOU 160(SI), X10 + MOVOU X10, 160(DI) + MOVOU 176(SI), X11 + MOVOU X11, 176(DI) + MOVOU 192(SI), X12 + MOVOU X12, 192(DI) + MOVOU 208(SI), X13 + MOVOU X13, 208(DI) + MOVOU 224(SI), X14 + MOVOU X14, 224(DI) + MOVOU 240(SI), X15 + MOVOU X15, 240(DI) + CMPQ BX, $256 + LEAQ 256(SI), SI + LEAQ 256(DI), DI + JGE move_257plus JMP tail move_1or2: @@ -209,42 +217,3 @@ move_129through256: MOVOU -16(SI)(BX*1), X15 MOVOU X15, -16(DI)(BX*1) RET -move_256through2048: - SUBQ $256, BX - MOVOU (SI), X0 - MOVOU X0, (DI) - MOVOU 16(SI), X1 - MOVOU X1, 16(DI) - MOVOU 32(SI), X2 - MOVOU X2, 32(DI) - MOVOU 48(SI), X3 - MOVOU X3, 48(DI) - MOVOU 64(SI), X4 - MOVOU X4, 64(DI) - MOVOU 80(SI), X5 - MOVOU X5, 80(DI) - MOVOU 96(SI), X6 - MOVOU X6, 96(DI) - MOVOU 112(SI), X7 - MOVOU X7, 112(DI) - MOVOU 128(SI), X8 - MOVOU X8, 128(DI) - MOVOU 144(SI), X9 - MOVOU X9, 144(DI) - MOVOU 160(SI), X10 - MOVOU X10, 160(DI) - MOVOU 176(SI), X11 - MOVOU X11, 176(DI) - MOVOU 192(SI), X12 - MOVOU X12, 192(DI) - MOVOU 208(SI), X13 - MOVOU X13, 208(DI) - MOVOU 224(SI), X14 - MOVOU X14, 224(DI) - MOVOU 240(SI), X15 - MOVOU X15, 240(DI) - CMPQ BX, $256 - LEAQ 256(SI), SI - LEAQ 256(DI), DI - JGE move_256through2048 - JMP tail diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go index f19a005f3..97b16c158 100644 --- a/pkg/segment/test/segment_test.go +++ b/pkg/segment/test/segment_test.go @@ -63,7 +63,7 @@ func checkSet(s *Set, expectedSegments int) error { return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) } if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want) + return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want) } prev = next havePrev = true diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 09bceabc9..1108fa0bd 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -97,7 +97,6 @@ func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) { if c < 0 { return 0, fmt.Errorf("bad binary.Size for %T", v) } - // TODO(b/38173783): Use a real context.Context. n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{}) if err != nil || c != n { return 0, err @@ -121,11 +120,9 @@ func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) { var err error if isVaddr { value := s.Arch.Native(uintptr(0)) - // TODO(b/38173783): Use a real context.Context. n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{}) *vaddr = usermem.Addr(s.Arch.Value(value)) } else { - // TODO(b/38173783): Use a real context.Context. n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{}) } if err != nil { diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go index 031fc64ec..8e5658c7a 100644 --- a/pkg/sentry/contexttest/contexttest.go +++ b/pkg/sentry/contexttest/contexttest.go @@ -97,7 +97,7 @@ type hostClock struct { } // Now implements ktime.Clock.Now. -func (hostClock) Now() ktime.Time { +func (*hostClock) Now() ktime.Time { return ktime.FromNanoseconds(time.Now().UnixNano()) } @@ -127,7 +127,7 @@ func (t *TestContext) Value(key interface{}) interface{} { case uniqueid.CtxInotifyCookie: return atomic.AddUint32(&lastInotifyCookie, 1) case ktime.CtxRealtimeClock: - return hostClock{} + return &hostClock{} default: if val, ok := t.otherValues[key]; ok { return val diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 0266a5287..65be12175 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -312,9 +312,9 @@ func (d *Dirent) SyncAll(ctx context.Context) { // There is nothing to sync for a read-only filesystem. if !d.Inode.MountSource.Flags.ReadOnly { - // FIXME(b/34856369): This should be a mount traversal, not a - // Dirent traversal, because some Inodes that need to be synced - // may no longer be reachable by name (after sys_unlink). + // NOTE(b/34856369): This should be a mount traversal, not a Dirent + // traversal, because some Inodes that need to be synced may no longer + // be reachable by name (after sys_unlink). // // Write out metadata, dirty page cached pages, and sync disk/remote // caches. diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index 5aff0cc95..a0082ecca 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -119,7 +119,7 @@ func TestNewPipe(t *testing.T) { continue } if flags := p.flags; test.flags != flags { - t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags) + t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags) continue } if len(test.readAheadBuffer) != len(p.readAheadBuffer) { @@ -136,7 +136,7 @@ func TestNewPipe(t *testing.T) { continue } if !fdnotifier.HasFD(int32(f.FD())) { - t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD) + t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD()) } } } diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go index ff96b28ba..edd6576aa 100644 --- a/pkg/sentry/fs/gofer/file_state.go +++ b/pkg/sentry/fs/gofer/file_state.go @@ -34,7 +34,6 @@ func (f *fileOperations) afterLoad() { flags := f.flags flags.Truncate = false - // TODO(b/38173783): Context is not plumbed to save/restore. f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps) if err != nil { return fmt.Errorf("failed to re-open handle: %v", err) diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index 9f7c3e89f..fc14249be 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -57,7 +57,6 @@ func (h *handles) DecRef() { } } } - // FIXME(b/38173783): Context is not plumbed here. if err := h.File.close(context.Background()); err != nil { log.Warningf("error closing p9 file: %v", err) } diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 1c934981b..a016c896e 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -273,7 +273,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle // operations on the old will see the new data. Then, make the new handle take // ownereship of the old FD and mark the old readHandle to not close the FD // when done. - if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil { + if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), syscall.O_CLOEXEC); err != nil { return err } @@ -710,13 +710,10 @@ func init() { } // AddLink implements InodeOperations.AddLink, but is currently a noop. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (*inodeOperations) AddLink() {} // DropLink implements InodeOperations.DropLink, but is currently a noop. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (*inodeOperations) DropLink() {} // NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {} diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go index 238f7804c..a3402e343 100644 --- a/pkg/sentry/fs/gofer/inode_state.go +++ b/pkg/sentry/fs/gofer/inode_state.go @@ -123,7 +123,6 @@ func (i *inodeFileState) afterLoad() { // beforeSave. return fmt.Errorf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings)) } - // TODO(b/38173783): Context is not plumbed to save/restore. ctx := &dummyClockContext{context.Background()} _, i.file, err = i.s.attach.walk(ctx, splitAbsolutePath(name)) diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go index 111da59f9..2d398b753 100644 --- a/pkg/sentry/fs/gofer/session_state.go +++ b/pkg/sentry/fs/gofer/session_state.go @@ -104,7 +104,6 @@ func (s *session) afterLoad() { // If private unix sockets are enabled, create and fill the session's endpoint // maps. if opts.privateunixsocket { - // TODO(b/38173783): Context is not plumbed to save/restore. ctx := &dummyClockContext{context.Background()} if err = s.restoreEndpointMaps(ctx); err != nil { diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go index 2d8d3a2ea..47a6c69bf 100644 --- a/pkg/sentry/fs/gofer/util.go +++ b/pkg/sentry/fs/gofer/util.go @@ -20,17 +20,29 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fs" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) func utimes(ctx context.Context, file contextFile, ts fs.TimeSpec) error { if ts.ATimeOmit && ts.MTimeOmit { return nil } + + // Replace requests to use the "system time" with the current time to + // ensure that timestamps remain consistent with the remote + // filesystem. + now := ktime.NowFromContext(ctx) + if ts.ATimeSetSystemTime { + ts.ATime = now + } + if ts.MTimeSetSystemTime { + ts.MTime = now + } mask := p9.SetAttrMask{ ATime: !ts.ATimeOmit, - ATimeNotSystemTime: !ts.ATimeSetSystemTime, + ATimeNotSystemTime: true, MTime: !ts.MTimeOmit, - MTimeNotSystemTime: !ts.MTimeSetSystemTime, + MTimeNotSystemTime: true, } as, ans := ts.ATime.Unix() ms, mns := ts.MTime.Unix() diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 1da3c0a17..62f1246aa 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -397,15 +397,12 @@ func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) { } // AddLink implements fs.InodeOperations.AddLink. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (i *inodeOperations) AddLink() {} // DropLink implements fs.InodeOperations.DropLink. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (i *inodeOperations) DropLink() {} // NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange. -// FIXME(b/63117438): Remove this from InodeOperations altogether. func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {} // readdirAll returns all of the directory entries in i. diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index eb4afe520..affdbcacb 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -199,14 +199,14 @@ func TestListen(t *testing.T) { } func TestPasscred(t *testing.T) { - e := ConnectedEndpoint{} + e := &ConnectedEndpoint{} if got, want := e.Passcred(), false; got != want { t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) } } func TestGetLocalAddress(t *testing.T) { - e := ConnectedEndpoint{path: "foo"} + e := &ConnectedEndpoint{path: "foo"} want := tcpip.FullAddress{Addr: tcpip.Address("foo")} if got, err := e.GetLocalAddress(); err != nil || got != want { t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil) @@ -214,7 +214,7 @@ func TestGetLocalAddress(t *testing.T) { } func TestQueuedSize(t *testing.T) { - e := ConnectedEndpoint{} + e := &ConnectedEndpoint{} tests := []struct { name string f func() int64 diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 55fb71c16..a34fbc946 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -102,7 +102,6 @@ func (i *Inode) DecRef() { // destroy releases the Inode and releases the msrc reference taken. func (i *Inode) destroy() { - // FIXME(b/38173783): Context is not plumbed here. ctx := context.Background() if err := i.WriteOut(ctx); err != nil { // FIXME(b/65209558): Mark as warning again once noatime is @@ -397,8 +396,6 @@ func (i *Inode) Getlink(ctx context.Context) (*Dirent, error) { // AddLink calls i.InodeOperations.AddLink. func (i *Inode) AddLink() { if i.overlay != nil { - // FIXME(b/63117438): Remove this from InodeOperations altogether. - // // This interface is only used by ramfs to update metadata of // children. These filesystems should _never_ have overlay // Inodes cached as children. So explicitly disallow this diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go index 94deb553b..1fc9c703c 100644 --- a/pkg/sentry/fs/proc/mounts.go +++ b/pkg/sentry/fs/proc/mounts.go @@ -170,7 +170,8 @@ func superBlockOpts(mountPath string, msrc *fs.MountSource) string { // NOTE(b/147673608): If the mount is a cgroup, we also need to include // the cgroup name in the options. For now we just read that from the // path. - // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we + // + // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we // should get this value from the cgroup itself, and not rely on the // path. if msrc.FilesystemType == "cgroup" { diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index d4c4b533d..702fdd392 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -80,7 +80,7 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir } // Truncate implements fs.InodeOperations.Truncate. -func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { +func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { return nil } @@ -196,7 +196,7 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f } // Truncate implements fs.InodeOperations.Truncate. -func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error { +func (*tcpSack) Truncate(context.Context, *fs.Inode, int64) error { return nil } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index d6c5dd2c1..4d42eac83 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -57,6 +57,16 @@ func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) { return m, nil } +func checkTaskState(t *kernel.Task) error { + switch t.ExitState() { + case kernel.TaskExitZombie: + return syserror.EACCES + case kernel.TaskExitDead: + return syserror.ESRCH + } + return nil +} + // taskDir represents a task-level directory. // // +stateify savable @@ -254,11 +264,12 @@ func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { } func (e *exe) executable() (file fsbridge.File, err error) { + if err := checkTaskState(e.t); err != nil { + return nil, err + } e.t.WithMuLocked(func(t *kernel.Task) { mm := t.MemoryManager() if mm == nil { - // TODO(b/34851096): Check shouldn't allow Readlink once the - // Task is zombied. err = syserror.EACCES return } @@ -268,7 +279,7 @@ func (e *exe) executable() (file fsbridge.File, err error) { // (with locks held). file = mm.Executable() if file == nil { - err = syserror.ENOENT + err = syserror.ESRCH } }) return @@ -313,11 +324,22 @@ func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs. return newProcInode(t, n, msrc, fs.Symlink, t) } +// Readlink reads the symlink value. +func (n *namespaceSymlink) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { + if err := checkTaskState(n.t); err != nil { + return "", err + } + return n.Symlink.Readlink(ctx, inode) +} + // Getlink implements fs.InodeOperations.Getlink. func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Dirent, error) { if !kernel.ContextCanTrace(ctx, n.t, false) { return nil, syserror.EACCES } + if err := checkTaskState(n.t); err != nil { + return nil, err + } // Create a new regular file to fake the namespace file. iops := fsutil.NewNoReadWriteFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0777), linux.PROC_SUPER_MAGIC) diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go index d5be56c3f..bc117ca6a 100644 --- a/pkg/sentry/fs/tmpfs/fs.go +++ b/pkg/sentry/fs/tmpfs/fs.go @@ -44,9 +44,6 @@ const ( // lookup. cacheRevalidate = "revalidate" - // TODO(edahlgren/mpratt): support a tmpfs size limit. - // size = "size" - // Permissions that exceed modeMask will be rejected. modeMask = 01777 diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index 79b808359..89168220a 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -26,22 +26,22 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// fsFile implements File interface over vfs.FileDescription. +// VFSFile implements File interface over vfs.FileDescription. // // +stateify savable -type vfsFile struct { +type VFSFile struct { file *vfs.FileDescription } -var _ File = (*vfsFile)(nil) +var _ File = (*VFSFile)(nil) // NewVFSFile creates a new File over fs.File. func NewVFSFile(file *vfs.FileDescription) File { - return &vfsFile{file: file} + return &VFSFile{file: file} } // PathnameWithDeleted implements File. -func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string { +func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string { root := vfs.RootFromContext(ctx) defer root.DecRef() @@ -51,7 +51,7 @@ func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string { } // ReadFull implements File. -func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { +func (f *VFSFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { var total int64 for dst.NumBytes() > 0 { n, err := f.file.PRead(ctx, dst, offset+total, vfs.ReadOptions{}) @@ -67,12 +67,12 @@ func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset i } // ConfigureMMap implements File. -func (f *vfsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { +func (f *VFSFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return f.file.ConfigureMMap(ctx, opts) } // Type implements File. -func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) { +func (f *VFSFile) Type(ctx context.Context) (linux.FileMode, error) { stat, err := f.file.Stat(ctx, vfs.StatOptions{}) if err != nil { return 0, err @@ -81,15 +81,21 @@ func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) { } // IncRef implements File. -func (f *vfsFile) IncRef() { +func (f *VFSFile) IncRef() { f.file.IncRef() } // DecRef implements File. -func (f *vfsFile) DecRef() { +func (f *VFSFile) DecRef() { f.file.DecRef() } +// FileDescription returns the FileDescription represented by f. It does not +// take an additional reference on the returned FileDescription. +func (f *VFSFile) FileDescription() *vfs.FileDescription { + return f.file +} + // fsLookup implements Lookup interface using fs.File. // // +stateify savable @@ -132,5 +138,5 @@ func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.Open if err != nil { return nil, err } - return &vfsFile{file: fd}, nil + return &VFSFile{file: fd}, nil } diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 48eaccdbc..afea58f65 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -476,7 +476,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { _, _, err := fs.walk(rp, false) if err != nil { return nil, err @@ -485,7 +485,7 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([ } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { _, _, err := fs.walk(rp, false) if err != nil { return "", err diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index d15a36709..99d1e3f8f 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) @@ -54,3 +54,13 @@ go_library( "//pkg/usermem", ], ) + +go_test( + name = "gofer_test", + srcs = ["gofer_test.go"], + library = ":gofer", + deps = [ + "//pkg/p9", + "//pkg/sentry/contexttest", + ], +) diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 5dbfc6250..49d9f859b 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -56,14 +56,19 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fd.mu.Lock() defer fd.mu.Unlock() + d := fd.dentry() if fd.dirents == nil { - ds, err := fd.dentry().getDirents(ctx) + ds, err := d.getDirents(ctx) if err != nil { return err } fd.dirents = ds } + if d.fs.opts.interop != InteropModeShared { + d.touchAtime(fd.vfsfd.Mount()) + } + for fd.off < int64(len(fd.dirents)) { if err := cb.Handle(fd.dirents[fd.off]); err != nil { return err diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 269624362..cd744bf5e 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -356,7 +356,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err := create(parent, name); err != nil { return err } - parent.touchCMtime(ctx) + if fs.opts.interop != InteropModeShared { + parent.touchCMtime() + } delete(parent.negativeChildren, name) parent.dirents = nil return nil @@ -435,14 +437,19 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b flags := uint32(0) if dir { if child != nil && !child.isDir() { + vfsObj.AbortDeleteDentry(childVFSD) return syserror.ENOTDIR } flags = linux.AT_REMOVEDIR } else { if child != nil && child.isDir() { + vfsObj.AbortDeleteDentry(childVFSD) return syserror.EISDIR } if rp.MustBeDir() { + if childVFSD != nil { + vfsObj.AbortDeleteDentry(childVFSD) + } return syserror.ENOTDIR } } @@ -454,7 +461,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return err } if fs.opts.interop != InteropModeShared { - parent.touchCMtime(ctx) + parent.touchCMtime() if dir { parent.decLinks() } @@ -802,7 +809,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving d.IncRef() // reference held by child on its parent d d.vfsd.InsertChild(&child.vfsd, name) if d.fs.opts.interop != InteropModeShared { - d.touchCMtime(ctx) delete(d.negativeChildren, name) d.dirents = nil } @@ -834,6 +840,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } childVFSFD = &fd.vfsfd } + if d.fs.opts.interop != InteropModeShared { + d.touchCMtime() + } return childVFSFD, nil } @@ -975,6 +984,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa oldParent.decLinks() newParent.incLinks() } + oldParent.touchCMtime() + newParent.touchCMtime() + renamed.touchCtime() } vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, &newParent.vfsd, newName, replacedVFSD) return nil @@ -1068,7 +1080,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(&ds) @@ -1076,11 +1088,11 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([ if err != nil { return nil, err } - return d.listxattr(ctx) + return d.listxattr(ctx, rp.Credentials(), size) } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(&ds) @@ -1088,7 +1100,7 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, nam if err != nil { return "", err } - return d.getxattr(ctx, name) + return d.getxattr(ctx, rp.Credentials(), &opts) } // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. @@ -1100,7 +1112,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt if err != nil { return err } - return d.setxattr(ctx, &opts) + return d.setxattr(ctx, rp.Credentials(), &opts) } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. @@ -1112,7 +1124,7 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, if err != nil { return err } - return d.removexattr(ctx, name) + return d.removexattr(ctx, rp.Credentials(), name) } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 8e41b6b1c..2485cdb53 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -34,6 +34,7 @@ package gofer import ( "fmt" "strconv" + "strings" "sync" "sync/atomic" "syscall" @@ -44,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -72,6 +74,9 @@ type filesystem struct { // client is the client used by this filesystem. client is immutable. client *p9.Client + // clock is a realtime clock used to set timestamps in file operations. + clock ktime.Clock + // uid and gid are the effective KUID and KGID of the filesystem's creator, // and are used as the owner and group for files that don't specify one. // uid and gid are immutable. @@ -376,6 +381,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt uid: creds.EffectiveKUID, gid: creds.EffectiveKGID, client: client, + clock: ktime.RealtimeClockFromContext(ctx), dentries: make(map[*dentry]struct{}), specialFileFDs: make(map[*specialFileFD]struct{}), } @@ -439,7 +445,8 @@ type dentry struct { // refs is the reference count. Each dentry holds a reference on its // parent, even if disowned. refs is accessed using atomic memory - // operations. + // operations. When refs reaches 0, the dentry may be added to the cache or + // destroyed. If refs==-1 the dentry has already been destroyed. refs int64 // fs is the owning filesystem. fs is immutable. @@ -779,10 +786,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin // data, so there's no cache to truncate either.) return nil } - now, haveNow := nowFromContext(ctx) - if !haveNow { - ctx.Warningf("gofer.dentry.setStat: current time not available") - } + now := d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_MODE != 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } @@ -794,25 +798,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } if setLocalAtime { if stat.Atime.Nsec == linux.UTIME_NOW { - if haveNow { - atomic.StoreInt64(&d.atime, now) - } + atomic.StoreInt64(&d.atime, now) } else { atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) } } if setLocalMtime { if stat.Mtime.Nsec == linux.UTIME_NOW { - if haveNow { - atomic.StoreInt64(&d.mtime, now) - } + atomic.StoreInt64(&d.mtime, now) } else { atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) } } - if haveNow { - atomic.StoreInt64(&d.ctime, now) - } + atomic.StoreInt64(&d.ctime, now) if stat.Mask&linux.STATX_SIZE != 0 { d.dataMu.Lock() oldSize := d.size @@ -864,7 +862,7 @@ func (d *dentry) IncRef() { func (d *dentry) TryIncRef() bool { for { refs := atomic.LoadInt64(&d.refs) - if refs == 0 { + if refs <= 0 { return false } if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { @@ -887,13 +885,20 @@ func (d *dentry) DecRef() { // checkCachingLocked should be called after d's reference count becomes 0 or it // becomes disowned. // +// It may be called on a destroyed dentry. For example, +// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times +// for the same dentry when the dentry is visited more than once in the same +// operation. One of the calls may destroy the dentry, so subsequent calls will +// do nothing. +// // Preconditions: d.fs.renameMu must be locked for writing. func (d *dentry) checkCachingLocked() { // Dentries with a non-zero reference count must be retained. (The only way // to obtain a reference on a dentry with zero references is via path // resolution, which requires renameMu, so if d.refs is zero then it will // remain zero while we hold renameMu for writing.) - if atomic.LoadInt64(&d.refs) != 0 { + refs := atomic.LoadInt64(&d.refs) + if refs > 0 { if d.cached { d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- @@ -901,6 +906,10 @@ func (d *dentry) checkCachingLocked() { } return } + if refs == -1 { + // Dentry has already been destroyed. + return + } // Non-child dentries with zero references are no longer reachable by path // resolution and should be dropped immediately. if d.vfsd.Parent() == nil || d.vfsd.IsDisowned() { @@ -953,9 +962,22 @@ func (d *dentry) checkCachingLocked() { } } +// destroyLocked destroys the dentry. It may flushes dirty pages from cache, +// close p9 file and remove reference on parent dentry. +// // Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is // not a child dentry. func (d *dentry) destroyLocked() { + switch atomic.LoadInt64(&d.refs) { + case 0: + // Mark the dentry destroyed. + atomic.StoreInt64(&d.refs, -1) + case -1: + panic("dentry.destroyLocked() called on already destroyed dentry") + default: + panic("dentry.destroyLocked() called with references on the dentry") + } + ctx := context.Background() d.handleMu.Lock() if !d.handle.file.isNil() { @@ -975,7 +997,10 @@ func (d *dentry) destroyLocked() { d.handle.close(ctx) } d.handleMu.Unlock() - d.file.close(ctx) + if !d.file.isNil() { + d.file.close(ctx) + d.file = p9file{} + } // Remove d from the set of all dentries. d.fs.syncMu.Lock() delete(d.fs.dentries, d) @@ -1000,21 +1025,50 @@ func (d *dentry) setDeleted() { atomic.StoreUint32(&d.deleted, 1) } -func (d *dentry) listxattr(ctx context.Context) ([]string, error) { - return nil, syserror.ENOTSUP +// We only support xattrs prefixed with "user." (see b/148380782). Currently, +// there is no need to expose any other xattrs through a gofer. +func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { + xattrMap, err := d.file.listXattr(ctx, size) + if err != nil { + return nil, err + } + xattrs := make([]string, 0, len(xattrMap)) + for x := range xattrMap { + if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) { + xattrs = append(xattrs, x) + } + } + return xattrs, nil } -func (d *dentry) getxattr(ctx context.Context, name string) (string, error) { - // TODO(jamieliu): add vfs.GetxattrOptions.Size - return d.file.getXattr(ctx, name, linux.XATTR_SIZE_MAX) +func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { + if err := d.checkPermissions(creds, vfs.MayRead); err != nil { + return "", err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return "", syserror.EOPNOTSUPP + } + return d.file.getXattr(ctx, opts.Name, opts.Size) } -func (d *dentry) setxattr(ctx context.Context, opts *vfs.SetxattrOptions) error { +func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error { + if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } -func (d *dentry) removexattr(ctx context.Context, name string) error { - return syserror.ENOTSUP +func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error { + if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + return d.file.removeXattr(ctx, name) } // Preconditions: d.isRegularFile() || d.isDirectory(). @@ -1065,7 +1119,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // description, but this doesn't matter since they refer to the // same file (unless d.fs.opts.overlayfsStaleRead is true, // which we handle separately). - if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil { + if err := syscall.Dup3(int(h.fd), int(d.handle.fd), syscall.O_CLOEXEC); err != nil { d.handleMu.Unlock() ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err) h.close(ctx) @@ -1165,21 +1219,21 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) } // Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context) ([]string, error) { - return fd.dentry().listxattr(ctx) +func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { + return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size) } // Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, name string) (string, error) { - return fd.dentry().getxattr(ctx, name) +func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { + return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts) } // Setxattr implements vfs.FileDescriptionImpl.Setxattr. func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { - return fd.dentry().setxattr(ctx, &opts) + return fd.dentry().setxattr(ctx, auth.CredentialsFromContext(ctx), &opts) } // Removexattr implements vfs.FileDescriptionImpl.Removexattr. func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { - return fd.dentry().removexattr(ctx, name) + return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name) } diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go new file mode 100644 index 000000000..82bc239db --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -0,0 +1,64 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "sync/atomic" + "testing" + + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sentry/contexttest" +) + +func TestDestroyIdempotent(t *testing.T) { + fs := filesystem{ + dentries: make(map[*dentry]struct{}), + opts: filesystemOptions{ + // Test relies on no dentry being held in the cache. + maxCachedDentries: 0, + }, + } + + ctx := contexttest.Context(t) + attr := &p9.Attr{ + Mode: p9.ModeRegular, + } + mask := p9.AttrMask{ + Mode: true, + Size: true, + } + parent, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr) + if err != nil { + t.Fatalf("fs.newDentry(): %v", err) + } + + child, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr) + if err != nil { + t.Fatalf("fs.newDentry(): %v", err) + } + parent.IncRef() // reference held by child on its parent. + parent.vfsd.InsertChild(&child.vfsd, "child") + + child.checkCachingLocked() + if got := atomic.LoadInt64(&child.refs); got != -1 { + t.Fatalf("child.refs=%d, want: -1", got) + } + // Parent will also be destroyed when child reference is removed. + if got := atomic.LoadInt64(&parent.refs); got != -1 { + t.Fatalf("parent.refs=%d, want: -1", got) + } + child.checkCachingLocked() + child.checkCachingLocked() +} diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 755ac2985..87f0b877f 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -85,6 +85,13 @@ func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAt return err } +func (f p9file) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) { + ctx.UninterruptibleSleepStart(false) + xattrs, err := f.file.ListXattr(size) + ctx.UninterruptibleSleepFinish(false) + return xattrs, err +} + func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) { ctx.UninterruptibleSleepStart(false) val, err := f.file.GetXattr(name, size) @@ -99,6 +106,13 @@ func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32) return err } +func (f p9file) removeXattr(ctx context.Context, name string) error { + ctx.UninterruptibleSleepStart(false) + err := f.file.RemoveXattr(name) + ctx.UninterruptibleSleepFinish(false) + return err +} + func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error { ctx.UninterruptibleSleepStart(false) err := f.file.Allocate(mode, offset, length) diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 3593eb1d5..857f7c74e 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -104,7 +104,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs putDentryReadWriter(rw) if d.fs.opts.interop != InteropModeShared { // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). - d.touchAtime(ctx, fd.vfsfd.Mount()) + d.touchAtime(fd.vfsfd.Mount()) } return n, err } @@ -139,10 +139,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // Compare Linux's mm/filemap.c:__generic_file_write_iter() => // file_update_time(). This is d.touchCMtime(), but without locking // d.metadataMu (recursively). - if now, ok := nowFromContext(ctx); ok { - atomic.StoreInt64(&d.mtime, now) - atomic.StoreInt64(&d.ctime, now) - } + d.touchCMtimeLocked() } if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { // Write dirty cached pages that will be touched by the write back to diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 274f7346f..507e0e276 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -76,7 +76,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // hold here since specialFileFD doesn't client-cache data. Just buffer the // read instead. if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { - d.touchAtime(ctx, fd.vfsfd.Mount()) + d.touchAtime(fd.vfsfd.Mount()) } buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) @@ -117,7 +117,7 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // Do a buffered write. See rationale in PRead. if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { - d.touchCMtime(ctx) + d.touchCMtime() } buf := make([]byte, src.NumBytes()) // Don't do partial writes if we get a partial read from src. diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go index adf43be60..2ec819f86 100644 --- a/pkg/sentry/fsimpl/gofer/symlink.go +++ b/pkg/sentry/fsimpl/gofer/symlink.go @@ -27,7 +27,7 @@ func (d *dentry) isSymlink() bool { // Precondition: d.isSymlink(). func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { if d.fs.opts.interop != InteropModeShared { - d.touchAtime(ctx, mnt) + d.touchAtime(mnt) d.dataMu.Lock() if d.haveTarget { target := d.target diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 7598ec6a8..2608e7e1d 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -18,8 +18,6 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -38,23 +36,12 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { } } -func nowFromContext(ctx context.Context) (int64, bool) { - if clock := ktime.RealtimeClockFromContext(ctx); clock != nil { - return clock.Now().Nanoseconds(), true - } - return 0, false -} - // Preconditions: fs.interop != InteropModeShared. -func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) { +func (d *dentry) touchAtime(mnt *vfs.Mount) { if err := mnt.CheckBeginWrite(); err != nil { return } - now, ok := nowFromContext(ctx) - if !ok { - mnt.EndWrite() - return - } + now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() atomic.StoreInt64(&d.atime, now) d.metadataMu.Unlock() @@ -63,13 +50,25 @@ func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) { // Preconditions: fs.interop != InteropModeShared. The caller has successfully // called vfs.Mount.CheckBeginWrite(). -func (d *dentry) touchCMtime(ctx context.Context) { - now, ok := nowFromContext(ctx) - if !ok { - return - } +func (d *dentry) touchCtime() { + now := d.fs.clock.Now().Nanoseconds() + d.metadataMu.Lock() + atomic.StoreInt64(&d.ctime, now) + d.metadataMu.Unlock() +} + +// Preconditions: fs.interop != InteropModeShared. The caller has successfully +// called vfs.Mount.CheckBeginWrite(). +func (d *dentry) touchCMtime() { + now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() atomic.StoreInt64(&d.mtime, now) atomic.StoreInt64(&d.ctime, now) d.metadataMu.Unlock() } + +func (d *dentry) touchCMtimeLocked() { + now := d.fs.clock.Now().Nanoseconds() + atomic.StoreInt64(&d.mtime, now) + atomic.StoreInt64(&d.ctime, now) +} diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 7d9dcd4c9..97fa7f7ab 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -74,31 +74,34 @@ func ImportFD(mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, err } // Retrieve metadata. - var s syscall.Stat_t - if err := syscall.Fstat(hostFD, &s); err != nil { + var s unix.Stat_t + if err := unix.Fstat(hostFD, &s); err != nil { return nil, err } fileMode := linux.FileMode(s.Mode) fileType := fileMode.FileType() - // Pipes, character devices, and sockets. - isStream := fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK + + // Determine if hostFD is seekable. If not, this syscall will return ESPIPE + // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character + // devices. + _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) + seekable := err != syserror.ESPIPE i := &inode{ hostFD: hostFD, - isStream: isStream, + seekable: seekable, isTTY: isTTY, canMap: canMap(uint32(fileType)), ino: fs.NextIno(), mode: fileMode, - // For simplicity, set offset to 0. Technically, we should - // only set to 0 on files that are not seekable (sockets, pipes, etc.), - // and use the offset from the host fd otherwise. + // For simplicity, set offset to 0. Technically, we should use the existing + // offset on the host if the file is seekable. offset: 0, } - // These files can't be memory mapped, assert this. - if i.isStream && i.canMap { + // Non-seekable files can't be memory mapped, assert this. + if !i.seekable && i.canMap { panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") } @@ -124,12 +127,12 @@ type inode struct { // This field is initialized at creation time and is immutable. hostFD int - // isStream is true if the host fd points to a file representing a stream, + // seekable is false if the host fd points to a file representing a stream, // e.g. a socket or a pipe. Such files are not seekable and can return // EWOULDBLOCK for I/O operations. // // This field is initialized at creation time and is immutable. - isStream bool + seekable bool // isTTY is true if this file represents a TTY. // @@ -481,8 +484,7 @@ func (f *fileDescription) Release() { // PRead implements FileDescriptionImpl. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { i := f.inode - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if i.isStream { + if !i.seekable { return 0, syserror.ESPIPE } @@ -492,8 +494,7 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off // Read implements FileDescriptionImpl. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { i := f.inode - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if i.isStream { + if !i.seekable { n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags) if isBlockError(err) { // If we got any data at all, return it as a "completed" partial read @@ -538,8 +539,7 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off // PWrite implements FileDescriptionImpl. func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { i := f.inode - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if i.isStream { + if !i.seekable { return 0, syserror.ESPIPE } @@ -549,8 +549,7 @@ func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, of // Write implements FileDescriptionImpl. func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { i := f.inode - // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. - if i.isStream { + if !i.seekable { n, err := writeToHostFD(ctx, i.hostFD, src, -1, opts.Flags) if isBlockError(err) { err = syserror.ErrWouldBlock @@ -593,8 +592,7 @@ func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offs // allow directory fds to be imported at all. func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (int64, error) { i := f.inode - // TODO(b/34716638): Some char devices do support seeking, e.g. /dev/null. - if i.isStream { + if !i.seekable { return 0, syserror.ESPIPE } diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index a429fa23d..baf81b4db 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -63,6 +63,9 @@ afterSymlink: rp.Advance() return nextVFSD, nil } + if len(name) > linux.NAME_MAX { + return nil, syserror.ENAMETOOLONG + } d.dirMu.Lock() nextVFSD, err := rp.ResolveChild(vfsd, name) if err != nil { @@ -76,16 +79,22 @@ afterSymlink: } // Resolve any symlink at current path component. if rp.ShouldFollowSymlink() && next.isSymlink() { - // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks". - target, err := next.inode.Readlink(ctx) + targetVD, targetPathname, err := next.inode.Getlink(ctx) if err != nil { return nil, err } - if err := rp.HandleSymlink(target); err != nil { - return nil, err + if targetVD.Ok() { + err := rp.HandleJump(targetVD) + targetVD.DecRef() + if err != nil { + return nil, err + } + } else { + if err := rp.HandleSymlink(targetPathname); err != nil { + return nil, err + } } goto afterSymlink - } rp.Advance() return &next.vfsd, nil @@ -191,6 +200,9 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v if pc == "." || pc == ".." { return "", syserror.EEXIST } + if len(pc) > linux.NAME_MAX { + return "", syserror.ENAMETOOLONG + } childVFSD, err := rp.ResolveChild(parentVFSD, pc) if err != nil { return "", err @@ -433,6 +445,9 @@ afterTrailingSymlink: if pc == "." || pc == ".." { return nil, syserror.EISDIR } + if len(pc) > linux.NAME_MAX { + return nil, syserror.ENAMETOOLONG + } // Determine whether or not we need to create a file. childVFSD, err := rp.ResolveChild(parentVFSD, pc) if err != nil { @@ -461,19 +476,25 @@ afterTrailingSymlink: } childDentry := childVFSD.Impl().(*Dentry) childInode := childDentry.inode - if rp.ShouldFollowSymlink() { - if childDentry.isSymlink() { - target, err := childInode.Readlink(ctx) + if rp.ShouldFollowSymlink() && childDentry.isSymlink() { + targetVD, targetPathname, err := childInode.Getlink(ctx) + if err != nil { + return nil, err + } + if targetVD.Ok() { + err := rp.HandleJump(targetVD) + targetVD.DecRef() if err != nil { return nil, err } - if err := rp.HandleSymlink(target); err != nil { + } else { + if err := rp.HandleSymlink(targetPathname); err != nil { return nil, err } - // rp.Final() may no longer be true since we now need to resolve the - // symlink target. - goto afterTrailingSymlink } + // rp.Final() may no longer be true since we now need to resolve the + // symlink target. + goto afterTrailingSymlink } if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { return nil, err @@ -661,7 +682,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu if err != nil { return linux.Statfs{}, err } - // TODO: actually implement statfs + // TODO(gvisor.dev/issue/1193): actually implement statfs. return linux.Statfs{}, syserror.ENOSYS } @@ -742,7 +763,7 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() @@ -755,7 +776,7 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([ } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 5c84b10c9..65f09af5d 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -181,6 +181,11 @@ func (InodeNotSymlink) Readlink(context.Context) (string, error) { return "", syserror.EINVAL } +// Getlink implements Inode.Getlink. +func (InodeNotSymlink) Getlink(context.Context) (vfs.VirtualDentry, string, error) { + return vfs.VirtualDentry{}, "", syserror.EINVAL +} + // InodeAttrs partially implements the Inode interface, specifically the // inodeMetadata sub interface. InodeAttrs provides functionality related to // inode attributes. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 2cefef020..ad76b9f64 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -414,7 +414,21 @@ type inodeDynamicLookup interface { } type inodeSymlink interface { - // Readlink resolves the target of a symbolic link. If an inode is not a + // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. Readlink(ctx context.Context) (string, error) + + // Getlink returns the target of a symbolic link, as used by path + // resolution: + // + // - If the inode is a "magic link" (a link whose target is most accurately + // represented as a VirtualDentry), Getlink returns (ok VirtualDentry, "", + // nil). A reference is taken on the returned VirtualDentry. + // + // - If the inode is an ordinary symlink, Getlink returns (zero-value + // VirtualDentry, symlink target, nil). + // + // - If the inode is not a symlink, Getlink returns (zero-value + // VirtualDentry, "", EINVAL). + Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) } diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 5918d3309..018aa503c 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -55,6 +55,11 @@ func (s *StaticSymlink) Readlink(_ context.Context) (string, error) { return s.target, nil } +// Getlink implements Inode.Getlink. +func (s *StaticSymlink) Getlink(_ context.Context) (vfs.VirtualDentry, string, error) { + return vfs.VirtualDentry{}, s.target, nil +} + // SetStat implements Inode.SetStat not allowing inode attributes to be changed. func (*StaticSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 8156984eb..17c1342b5 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/log", "//pkg/refs", "//pkg/safemem", - "//pkg/sentry/fs", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/inet", diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index aee2a4392..888afc0fd 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -214,22 +214,6 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData { return &ioData{ioUsage: t} } -func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry { - // Namespace symlinks should contain the namespace name and the inode number - // for the namespace instance, so for example user:[123456]. We currently fake - // the inode number by sticking the symlink inode in its place. - target := fmt.Sprintf("%s:[%d]", ns, ino) - - inode := &kernfs.StaticSymlink{} - // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), ino, target) - - taskInode := &taskOwnedInode{Inode: inode, owner: task} - d := &kernfs.Dentry{} - d.Init(taskInode) - return d -} - // newCgroupData creates inode that shows cgroup information. // From man 7 cgroups: "For each cgroup hierarchy of which the process is a // member, there is one entry containing three colon-separated fields: diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 76bfc5307..046265eca 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -30,34 +30,35 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -type fdDir struct { - inoGen InoGenerator - task *kernel.Task - - // When produceSymlinks is set, dirents produces for the FDs are reported - // as symlink. Otherwise, they are reported as regular files. - produceSymlink bool -} - -func (i *fdDir) lookup(name string) (*vfs.FileDescription, kernel.FDFlags, error) { - fd, err := strconv.ParseUint(name, 10, 64) - if err != nil { - return nil, kernel.FDFlags{}, syserror.ENOENT - } - +func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) { var ( file *vfs.FileDescription flags kernel.FDFlags ) - i.task.WithMuLocked(func(t *kernel.Task) { - if fdTable := t.FDTable(); fdTable != nil { - file, flags = fdTable.GetVFS2(int32(fd)) + t.WithMuLocked(func(t *kernel.Task) { + if fdt := t.FDTable(); fdt != nil { + file, flags = fdt.GetVFS2(fd) } }) + return file, flags +} + +func taskFDExists(t *kernel.Task, fd int32) bool { + file, _ := getTaskFD(t, fd) if file == nil { - return nil, kernel.FDFlags{}, syserror.ENOENT + return false } - return file, flags, nil + file.DecRef() + return true +} + +type fdDir struct { + inoGen InoGenerator + task *kernel.Task + + // When produceSymlinks is set, dirents produces for the FDs are reported + // as symlink. Otherwise, they are reported as regular files. + produceSymlink bool } // IterDirents implements kernfs.inodeDynamicLookup. @@ -128,11 +129,15 @@ func newFDDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { // Lookup implements kernfs.inodeDynamicLookup. func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - file, _, err := i.lookup(name) + fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { - return nil, err + return nil, syserror.ENOENT + } + fd := int32(fdInt) + if !taskFDExists(i.task, fd) { + return nil, syserror.ENOENT } - taskDentry := newFDSymlink(i.task.Credentials(), file, i.inoGen.NextIno()) + taskDentry := newFDSymlink(i.task, fd, i.inoGen.NextIno()) return taskDentry.VFSDentry(), nil } @@ -169,19 +174,22 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia // // +stateify savable type fdSymlink struct { - refs.AtomicRefCount kernfs.InodeAttrs + kernfs.InodeNoopRefCount kernfs.InodeSymlink - file *vfs.FileDescription + task *kernel.Task + fd int32 } var _ kernfs.Inode = (*fdSymlink)(nil) -func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64) *kernfs.Dentry { - file.IncRef() - inode := &fdSymlink{file: file} - inode.Init(creds, ino, linux.ModeSymlink|0777) +func newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry { + inode := &fdSymlink{ + task: task, + fd: fd, + } + inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) @@ -189,21 +197,25 @@ func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64 } func (s *fdSymlink) Readlink(ctx context.Context) (string, error) { + file, _ := getTaskFD(s.task, s.fd) + if file == nil { + return "", syserror.ENOENT + } + defer file.DecRef() root := vfs.RootFromContext(ctx) defer root.DecRef() - - vfsObj := s.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem() - return vfsObj.PathnameWithDeleted(ctx, root, s.file.VirtualDentry()) -} - -func (s *fdSymlink) DecRef() { - s.AtomicRefCount.DecRefWithDestructor(func() { - s.Destroy() - }) + return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry()) } -func (s *fdSymlink) Destroy() { - s.file.DecRef() +func (s *fdSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) { + file, _ := getTaskFD(s.task, s.fd) + if file == nil { + return vfs.VirtualDentry{}, "", syserror.ENOENT + } + defer file.DecRef() + vd := file.VirtualDentry() + vd.IncRef() + return vd, "", nil } // fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory. @@ -238,12 +250,18 @@ func newFDInfoDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { // Lookup implements kernfs.inodeDynamicLookup. func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { - file, flags, err := i.lookup(name) + fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { - return nil, err + return nil, syserror.ENOENT + } + fd := int32(fdInt) + if !taskFDExists(i.task, fd) { + return nil, syserror.ENOENT + } + data := &fdInfoData{ + task: i.task, + fd: fd, } - - data := &fdInfoData{file: file, flags: flags} dentry := newTaskOwnedFile(i.task, i.inoGen.NextIno(), 0444, data) return dentry.VFSDentry(), nil } @@ -262,26 +280,23 @@ type fdInfoData struct { kernfs.DynamicBytesFile refs.AtomicRefCount - file *vfs.FileDescription - flags kernel.FDFlags + task *kernel.Task + fd int32 } var _ dynamicInode = (*fdInfoData)(nil) -func (d *fdInfoData) DecRef() { - d.AtomicRefCount.DecRefWithDestructor(d.destroy) -} - -func (d *fdInfoData) destroy() { - d.file.DecRef() -} - // Generate implements vfs.DynamicBytesSource.Generate. func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { + file, descriptorFlags := getTaskFD(d.task, d.fd) + if file == nil { + return syserror.ENOENT + } + defer file.DecRef() // TODO(b/121266871): Include pos, locks, and other data. For now we only // have flags. // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt - flags := uint(d.file.StatusFlags()) | d.flags.ToLinuxFileFlags() + flags := uint(file.StatusFlags()) | descriptorFlags.ToLinuxFileFlags() fmt.Fprintf(buf, "flags:\t0%o\n", flags) return nil } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index df0d1bcc5..2c6f8bdfc 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -64,6 +64,16 @@ func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) { return m, nil } +func checkTaskState(t *kernel.Task) error { + switch t.ExitState() { + case kernel.TaskExitZombie: + return syserror.EACCES + case kernel.TaskExitDead: + return syserror.ESRCH + } + return nil +} + type bufferWriter struct { buf *bytes.Buffer } @@ -610,12 +620,31 @@ func (s *exeSymlink) Readlink(ctx context.Context) (string, error) { return exec.PathnameWithDeleted(ctx), nil } +// Getlink implements kernfs.Inode.Getlink. +func (s *exeSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) { + if !kernel.ContextCanTrace(ctx, s.task, false) { + return vfs.VirtualDentry{}, "", syserror.EACCES + } + + exec, err := s.executable() + if err != nil { + return vfs.VirtualDentry{}, "", err + } + defer exec.DecRef() + + vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() + vd.IncRef() + return vd, "", nil +} + func (s *exeSymlink) executable() (file fsbridge.File, err error) { + if err := checkTaskState(s.task); err != nil { + return nil, err + } + s.task.WithMuLocked(func(t *kernel.Task) { mm := t.MemoryManager() if mm == nil { - // TODO(b/34851096): Check shouldn't allow Readlink once the - // Task is zombied. err = syserror.EACCES return } @@ -625,7 +654,7 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) { // (with locks held). file = mm.Executable() if file == nil { - err = syserror.ENOENT + err = syserror.ESRCH } }) return @@ -692,3 +721,41 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error { i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf) return nil } + +type namespaceSymlink struct { + kernfs.StaticSymlink + + task *kernel.Task +} + +func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry { + // Namespace symlinks should contain the namespace name and the inode number + // for the namespace instance, so for example user:[123456]. We currently fake + // the inode number by sticking the symlink inode in its place. + target := fmt.Sprintf("%s:[%d]", ns, ino) + + inode := &namespaceSymlink{task: task} + // Note: credentials are overridden by taskOwnedInode. + inode.Init(task.Credentials(), ino, target) + + taskInode := &taskOwnedInode{Inode: inode, owner: task} + d := &kernfs.Dentry{} + d.Init(taskInode) + return d +} + +// Readlink implements Inode. +func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) { + if err := checkTaskState(s.task); err != nil { + return "", err + } + return s.StaticSymlink.Readlink(ctx) +} + +// Getlink implements Inode.Getlink. +func (s *namespaceSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) { + if err := checkTaskState(s.task); err != nil { + return vfs.VirtualDentry{}, "", err + } + return s.StaticSymlink.Getlink(ctx) +} diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 373a7b17d..6595fcee6 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -32,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/usermem" @@ -206,22 +206,21 @@ var _ dynamicInode = (*netUnixData)(nil) func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n") for _, se := range n.kernel.ListSockets() { - s := se.Sock.Get() - if s == nil { - log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock) + s := se.SockVFS2 + if !s.TryIncRef() { + log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) continue } - sfile := s.(*fs.File) - if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { + if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX { s.DecRef() // Not a unix socket. continue } - sops := sfile.FileOperations.(*unix.SocketOperations) + sops := s.Impl().(*unix.SocketVFS2) addr, err := sops.Endpoint().GetLocalAddress() if err != nil { - log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err) + log.Warningf("Failed to retrieve socket name from %+v: %v", s, err) addr.Addr = "<unknown>" } @@ -234,6 +233,15 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { } } + // Get inode number. + var ino uint64 + stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_INO}) + if statErr != nil || stat.Mask&linux.STATX_INO == 0 { + log.Warningf("Failed to retrieve ino for socket file: %v", statErr) + } else { + ino = stat.Ino + } + // In the socket entry below, the value for the 'Num' field requires // some consideration. Linux prints the address to the struct // unix_sock representing a socket in the kernel, but may redact the @@ -252,14 +260,14 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { // the definition of this struct changes over time. // // For now, we always redact this pointer. - fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d", + fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d", (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct. - sfile.ReadRefs()-1, // RefCount, don't count our own ref. + s.Refs()-1, // RefCount, don't count our own ref. 0, // Protocol, always 0 for UDS. sockFlags, // Flags. sops.Endpoint().Type(), // Type. sops.State(), // State. - sfile.InodeID(), // Inode. + ino, // Inode. ) // Path @@ -341,15 +349,14 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, t := kernel.TaskFromContext(ctx) for _, se := range k.ListSockets() { - s := se.Sock.Get() - if s == nil { - log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) + s := se.SockVFS2 + if !s.TryIncRef() { + log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) continue } - sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(socket.Socket) + sops, ok := s.Impl().(socket.SocketVFS2) if !ok { - panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) + panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s)) } if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { s.DecRef() @@ -398,14 +405,15 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, // Unimplemented. fmt.Fprintf(buf, "%08X ", 0) + stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO}) + // Field: uid. - uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx) - if err != nil { - log.Warningf("Failed to retrieve unstable attr for socket file: %v", err) + if statErr != nil || stat.Mask&linux.STATX_UID == 0 { + log.Warningf("Failed to retrieve uid for socket file: %v", statErr) fmt.Fprintf(buf, "%5d ", 0) } else { creds := auth.CredentialsFromContext(ctx) - fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow())) + fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow())) } // Field: timeout; number of unanswered 0-window probes. @@ -413,11 +421,16 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, fmt.Fprintf(buf, "%8d ", 0) // Field: inode. - fmt.Fprintf(buf, "%8d ", sfile.InodeID()) + if statErr != nil || stat.Mask&linux.STATX_INO == 0 { + log.Warningf("Failed to retrieve inode for socket file: %v", statErr) + fmt.Fprintf(buf, "%8d ", 0) + } else { + fmt.Fprintf(buf, "%8d ", stat.Ino) + } // Field: refcount. Don't count the ref we obtain while deferencing // the weakref to this socket. - fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1) + fmt.Fprintf(buf, "%d ", s.Refs()-1) // Field: Socket struct address. Redacted due to the same reason as // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. @@ -499,15 +512,14 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { t := kernel.TaskFromContext(ctx) for _, se := range d.kernel.ListSockets() { - s := se.Sock.Get() - if s == nil { - log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) + s := se.SockVFS2 + if !s.TryIncRef() { + log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) continue } - sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(socket.Socket) + sops, ok := s.Impl().(socket.SocketVFS2) if !ok { - panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) + panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s)) } if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM { s.DecRef() @@ -551,25 +563,31 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Field: retrnsmt. Always 0 for UDP. fmt.Fprintf(buf, "%08X ", 0) + stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO}) + // Field: uid. - uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx) - if err != nil { - log.Warningf("Failed to retrieve unstable attr for socket file: %v", err) + if statErr != nil || stat.Mask&linux.STATX_UID == 0 { + log.Warningf("Failed to retrieve uid for socket file: %v", statErr) fmt.Fprintf(buf, "%5d ", 0) } else { creds := auth.CredentialsFromContext(ctx) - fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow())) + fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow())) } // Field: timeout. Always 0 for UDP. fmt.Fprintf(buf, "%8d ", 0) // Field: inode. - fmt.Fprintf(buf, "%8d ", sfile.InodeID()) + if statErr != nil || stat.Mask&linux.STATX_INO == 0 { + log.Warningf("Failed to retrieve inode for socket file: %v", statErr) + fmt.Fprintf(buf, "%8d ", 0) + } else { + fmt.Fprintf(buf, "%8d ", stat.Ino) + } // Field: ref; reference count on the socket inode. Don't count the ref // we obtain while deferencing the weakref to this socket. - fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1) + fmt.Fprintf(buf, "%d ", s.Refs()-1) // Field: Socket struct address. Redacted due to the same reason as // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. @@ -670,9 +688,9 @@ func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { if line.prefix == "Tcp" { tcp := stat.(*inet.StatSNMPTCP) // "Tcp" needs special processing because MaxConn is signed. RFC 2012. - fmt.Sprintf("%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) + fmt.Fprintf(buf, "%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) } else { - fmt.Sprintf("%s: %s\n", line.prefix, sprintSlice(toSlice(stat))) + fmt.Fprintf(buf, "%s: %s\n", line.prefix, sprintSlice(toSlice(stat))) } } return nil diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 882c1981e..4621e2de0 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -63,6 +63,11 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { return strconv.FormatUint(uint64(tgid), 10), nil } +func (s *selfSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx) + return vfs.VirtualDentry{}, target, err +} + // SetStat implements Inode.SetStat not allowing inode attributes to be changed. func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM @@ -101,6 +106,11 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { return fmt.Sprintf("%d/task/%d", tgid, tid), nil } +func (s *threadSelfSymlink) Getlink(ctx context.Context) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx) + return vfs.VirtualDentry{}, target, err +} + // SetStat implements Inode.SetStat not allowing inode attributes to be changed. func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD index 790d50e65..52084ddb5 100644 --- a/pkg/sentry/fsimpl/sockfs/BUILD +++ b/pkg/sentry/fsimpl/sockfs/BUILD @@ -7,6 +7,7 @@ go_library( srcs = ["sockfs.go"], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", "//pkg/context", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index c13511de2..3f7ad1d65 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -16,6 +16,7 @@ package sockfs import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -60,6 +61,10 @@ type filesystem struct { } // inode implements kernfs.Inode. +// +// TODO(gvisor.dev/issue/1476): Add device numbers to this inode (which are +// not included in InodeAttrs) to store the numbers of the appropriate +// socket device. Override InodeAttrs.Stat() accordingly. type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink @@ -71,3 +76,27 @@ type inode struct { func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return nil, syserror.ENXIO } + +// InitSocket initializes a socket FileDescription, with a corresponding +// Dentry in mnt. +// +// fd should be the FileDescription associated with socketImpl, i.e. its first +// field. mnt should be the global socket mount, Kernel.socketMount. +func InitSocket(socketImpl vfs.FileDescriptionImpl, fd *vfs.FileDescription, mnt *vfs.Mount, creds *auth.Credentials) error { + fsimpl := mnt.Filesystem().Impl() + fs := fsimpl.(*kernfs.Filesystem) + + // File mode matches net/socket.c:sock_alloc. + filemode := linux.FileMode(linux.S_IFSOCK | 0600) + i := &inode{} + i.Init(creds, fs.NextIno(), filemode) + + d := &kernfs.Dentry{} + d.Init(i) + + opts := &vfs.FileDescriptionOptions{UseDentryMetadata: true} + if err := fd.Init(socketImpl, linux.O_RDWR, mnt, d.VFSDentry(), opts); err != nil { + return err + } + return nil +} diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 6ea35affb..4e6cd3491 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -24,6 +24,7 @@ go_library( "filesystem.go", "named_pipe.go", "regular_file.go", + "socket_file.go", "symlink.go", "tmpfs.go", ], @@ -50,6 +51,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/vfs", "//pkg/sentry/vfs/lock", + "//pkg/sentry/vfs/memxattr", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 37c75ab64..45712c9b9 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -68,6 +68,8 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fs.mu.Lock() defer fs.mu.Unlock() + fd.inode().touchAtime(fd.vfsfd.Mount()) + if fd.off == 0 { if err := cb.Handle(vfs.Dirent{ Name: ".", diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e678ecc37..f4d50d64f 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -46,6 +46,9 @@ func stepLocked(rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { return nil, err } afterSymlink: + if len(rp.Component()) > linux.NAME_MAX { + return nil, syserror.ENAMETOOLONG + } nextVFSD, err := rp.ResolveComponent(&d.vfsd) if err != nil { return nil, err @@ -57,7 +60,7 @@ afterSymlink: } next := nextVFSD.Impl().(*dentry) if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { - // TODO(gvisor.dev/issues/1197): Symlink traversals updates + // TODO(gvisor.dev/issue/1197): Symlink traversals updates // access time. if err := rp.HandleSymlink(symlink.target); err != nil { return nil, err @@ -133,6 +136,9 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if name == "." || name == ".." { return syserror.EEXIST } + if len(name) > linux.NAME_MAX { + return syserror.ENAMETOOLONG + } // Call parent.vfsd.Child() instead of stepLocked() or rp.ResolveChild(), // because if the child exists we want to return EEXIST immediately instead // of attempting symlink/mount traversal. @@ -153,7 +159,11 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa return err } defer mnt.EndWrite() - return create(parent, name) + if err := create(parent, name); err != nil { + return err + } + parent.inode.touchCMtime() + return nil } // AccessAt implements vfs.Filesystem.Impl.AccessAt. @@ -251,8 +261,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v case linux.S_IFCHR: childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor) case linux.S_IFSOCK: - // Not yet supported. - return syserror.EPERM + childInode = fs.newSocketFile(rp.Credentials(), opts.Mode, opts.Endpoint) default: return syserror.EINVAL } @@ -328,7 +337,12 @@ afterTrailingSymlink: child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode)) parent.vfsd.InsertChild(&child.vfsd, name) parent.inode.impl.(*directory).childList.PushBack(child) - return child.open(ctx, rp, &opts, true) + fd, err := child.open(ctx, rp, &opts, true) + if err != nil { + return nil, err + } + parent.inode.touchCMtime() + return fd, nil } if err != nil { return nil, err @@ -381,6 +395,8 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open return newNamedPipeFD(ctx, impl, rp, &d.vfsd, opts.Flags) case *deviceFile: return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts) + case *socketFile: + return nil, syserror.ENXIO default: panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl)) } @@ -398,6 +414,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st if !ok { return "", syserror.EINVAL } + symlink.inode.touchAtime(rp.Mount()) return symlink.target, nil } @@ -515,7 +532,10 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa oldParent.inode.decLinksLocked() newParent.inode.incLinksLocked() } - // TODO(gvisor.dev/issues/1197): Update timestamps and parent directory + oldParent.inode.touchCMtime() + newParent.inode.touchCMtime() + renamed.inode.touchCtime() + // TODO(gvisor.dev/issue/1197): Update timestamps and parent directory // sizes. vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD) return nil @@ -565,6 +585,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error parent.inode.decLinksLocked() // from child's ".." child.inode.decLinksLocked() vfsObj.CommitDeleteDentry(childVFSD) + parent.inode.touchCMtime() return nil } @@ -600,7 +621,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu if err != nil { return linux.Statfs{}, err } - // TODO(gvisor.dev/issues/1197): Actually implement statfs. + // TODO(gvisor.dev/issue/1197): Actually implement statfs. return linux.Statfs{}, syserror.ENOSYS } @@ -654,62 +675,68 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error parent.inode.impl.(*directory).childList.Remove(child) child.inode.decLinksLocked() vfsObj.CommitDeleteDentry(childVFSD) + parent.inode.touchCMtime() return nil } // BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. -// -// TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) { - return nil, syserror.ECONNREFUSED + fs.mu.RLock() + defer fs.mu.RUnlock() + d, err := resolveLocked(rp) + if err != nil { + return nil, err + } + switch impl := d.inode.impl.(type) { + case *socketFile: + return impl.ep, nil + default: + return nil, syserror.ECONNREFUSED + } } // ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return nil, err } - // TODO(b/127675828): support extended attributes - return nil, syserror.ENOTSUP + return d.inode.listxattr(size) } // GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return "", err } - // TODO(b/127675828): support extended attributes - return "", syserror.ENOTSUP + return d.inode.getxattr(rp.Credentials(), &opts) } // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return err } - // TODO(b/127675828): support extended attributes - return syserror.ENOTSUP + return d.inode.setxattr(rp.Credentials(), &opts) } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() defer fs.mu.RUnlock() - _, err := resolveLocked(rp) + d, err := resolveLocked(rp) if err != nil { return err } - // TODO(b/127675828): support extended attributes - return syserror.ENOTSUP + return d.inode.removexattr(rp.Credentials(), name) } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 26cd65605..57e5e28ec 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -286,7 +286,8 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs rw := getRegularFileReadWriter(f, offset) n, err := dst.CopyOutFrom(ctx, rw) putRegularFileReadWriter(rw) - return int64(n), err + fd.inode().touchAtime(fd.vfsfd.Mount()) + return n, err } // Read implements vfs.FileDescriptionImpl.Read. @@ -323,6 +324,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) + fd.inode().touchCMtimeLocked() f.inode.mu.Unlock() putRegularFileReadWriter(rw) return n, err diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go index 0c6b7924c..25c2321af 100644 --- a/pkg/tcpip/stack/packet_buffer_state.go +++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go @@ -12,16 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +package tmpfs -import "gvisor.dev/gvisor/pkg/tcpip/buffer" +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" +) -// beforeSave is invoked by stateify. -func (pk *PacketBuffer) beforeSave() { - // Non-Data fields may be slices of the Data field. This causes - // problems for SR, so during save we make each header independent. - pk.Header = pk.Header.DeepCopy() - pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) - pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) - pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +// socketFile is a socket (=S_IFSOCK) tmpfs file. +type socketFile struct { + inode inode + ep transport.BoundEndpoint +} + +func (fs *filesystem) newSocketFile(creds *auth.Credentials, mode linux.FileMode, ep transport.BoundEndpoint) *inode { + file := &socketFile{ep: ep} + file.inode.init(file, fs, creds, mode) + file.inode.nlink = 1 // from parent directory + return &file.inode } diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go index ebe035dee..d4f59ee5b 100644 --- a/pkg/sentry/fsimpl/tmpfs/stat_test.go +++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go @@ -29,7 +29,7 @@ func TestStatAfterCreate(t *testing.T) { mode := linux.FileMode(0644) // Run with different file types. - // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets. + // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets. for _, typ := range []string{"file", "dir", "pipe"} { t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { var ( @@ -140,7 +140,7 @@ func TestSetStatAtime(t *testing.T) { Mask: 0, Atime: linux.NsecToStatxTimestamp(100), }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") + t.Errorf("SetStat atime without mask failed: %v", err) } // Atime should be unchanged. if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { @@ -155,7 +155,7 @@ func TestSetStatAtime(t *testing.T) { Atime: linux.NsecToStatxTimestamp(100), } if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") + t.Errorf("SetStat atime with mask failed: %v", err) } if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { t.Errorf("Stat got error: %v", err) @@ -169,7 +169,7 @@ func TestSetStat(t *testing.T) { mode := linux.FileMode(0644) // Run with different file types. - // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets. + // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets. for _, typ := range []string{"file", "dir", "pipe"} { t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) { var ( @@ -205,7 +205,7 @@ func TestSetStat(t *testing.T) { Mask: 0, Atime: linux.NsecToStatxTimestamp(100), }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") + t.Errorf("SetStat atime without mask failed: %v", err) } // Atime should be unchanged. if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { @@ -220,7 +220,7 @@ func TestSetStat(t *testing.T) { Atime: linux.NsecToStatxTimestamp(100), } if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") + t.Errorf("SetStat atime with mask failed: %v", err) } if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { t.Errorf("Stat got error: %v", err) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 8bc8818c0..9fa8637d5 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -27,6 +27,7 @@ package tmpfs import ( "fmt" "math" + "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -37,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/vfs/lock" + "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -186,6 +188,11 @@ type inode struct { // filesystem.RmdirAt() drops the reference. refs int64 + // xattrs implements extended attributes. + // + // TODO(b/148380782): Support xattrs other than user.* + xattrs memxattr.SimpleExtendedAttributes + // Inode metadata. Writing multiple fields atomically requires holding // mu, othewise atomic operations can be used. mu sync.Mutex @@ -315,7 +322,7 @@ func (i *inode) statTo(stat *linux.Statx) { stat.Atime = linux.NsecToStatxTimestamp(i.atime) stat.Ctime = linux.NsecToStatxTimestamp(i.ctime) stat.Mtime = linux.NsecToStatxTimestamp(i.mtime) - // TODO(gvisor.dev/issues/1197): Device number. + // TODO(gvisor.dev/issue/1197): Device number. switch impl := i.impl.(type) { case *regularFile: stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS @@ -331,7 +338,7 @@ func (i *inode) statTo(stat *linux.Statx) { case *deviceFile: stat.RdevMajor = impl.major stat.RdevMinor = impl.minor - case *directory, *namedPipe: + case *socketFile, *directory, *namedPipe: // Nothing to do. default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) @@ -385,28 +392,41 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return syserror.EINVAL } } + now := i.clock.Now().Nanoseconds() if mask&linux.STATX_ATIME != 0 { - atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped()) + if stat.Atime.Nsec == linux.UTIME_NOW { + atomic.StoreInt64(&i.atime, now) + } else { + atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped()) + } needsCtimeBump = true } if mask&linux.STATX_MTIME != 0 { - atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped()) + if stat.Mtime.Nsec == linux.UTIME_NOW { + atomic.StoreInt64(&i.mtime, now) + } else { + atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped()) + } needsCtimeBump = true // Ignore the mtime bump, since we just set it ourselves. needsMtimeBump = false } if mask&linux.STATX_CTIME != 0 { - atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped()) + if stat.Ctime.Nsec == linux.UTIME_NOW { + atomic.StoreInt64(&i.ctime, now) + } else { + atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped()) + } // Ignore the ctime bump, since we just set it ourselves. needsCtimeBump = false } - now := i.clock.Now().Nanoseconds() if needsMtimeBump { atomic.StoreInt64(&i.mtime, now) } if needsCtimeBump { atomic.StoreInt64(&i.ctime, now) } + i.mu.Unlock() return nil } @@ -466,6 +486,8 @@ func (i *inode) direntType() uint8 { return linux.DT_DIR case *symlink: return linux.DT_LNK + case *socketFile: + return linux.DT_SOCK case *deviceFile: switch impl.kind { case vfs.BlockDevice: @@ -484,6 +506,92 @@ func (i *inode) isDir() bool { return linux.FileMode(i.mode).FileType() == linux.S_IFDIR } +func (i *inode) touchAtime(mnt *vfs.Mount) { + if err := mnt.CheckBeginWrite(); err != nil { + return + } + now := i.clock.Now().Nanoseconds() + i.mu.Lock() + atomic.StoreInt64(&i.atime, now) + i.mu.Unlock() + mnt.EndWrite() +} + +// Preconditions: The caller has called vfs.Mount.CheckBeginWrite(). +func (i *inode) touchCtime() { + now := i.clock.Now().Nanoseconds() + i.mu.Lock() + atomic.StoreInt64(&i.ctime, now) + i.mu.Unlock() +} + +// Preconditions: The caller has called vfs.Mount.CheckBeginWrite(). +func (i *inode) touchCMtime() { + now := i.clock.Now().Nanoseconds() + i.mu.Lock() + atomic.StoreInt64(&i.mtime, now) + atomic.StoreInt64(&i.ctime, now) + i.mu.Unlock() +} + +// Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds +// inode.mu. +func (i *inode) touchCMtimeLocked() { + now := i.clock.Now().Nanoseconds() + atomic.StoreInt64(&i.mtime, now) + atomic.StoreInt64(&i.ctime, now) +} + +func (i *inode) listxattr(size uint64) ([]string, error) { + return i.xattrs.Listxattr(size) +} + +func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { + if err := i.checkPermissions(creds, vfs.MayRead); err != nil { + return "", err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return "", syserror.EOPNOTSUPP + } + if !i.userXattrSupported() { + return "", syserror.ENODATA + } + return i.xattrs.Getxattr(opts) +} + +func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error { + if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + if !i.userXattrSupported() { + return syserror.EPERM + } + return i.xattrs.Setxattr(opts) +} + +func (i *inode) removexattr(creds *auth.Credentials, name string) error { + if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + if !i.userXattrSupported() { + return syserror.EPERM + } + return i.xattrs.Removexattr(name) +} + +// Extended attributes in the user.* namespace are only supported for regular +// files and directories. +func (i *inode) userXattrSupported() bool { + filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode) + return filetype == linux.S_IFREG || filetype == linux.S_IFDIR +} + // fileDescription is embedded by tmpfs implementations of // vfs.FileDescriptionImpl. type fileDescription struct { @@ -511,3 +619,23 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) creds := auth.CredentialsFromContext(ctx) return fd.inode().setStat(ctx, creds, &opts.Stat) } + +// Listxattr implements vfs.FileDescriptionImpl.Listxattr. +func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { + return fd.inode().listxattr(size) +} + +// Getxattr implements vfs.FileDescriptionImpl.Getxattr. +func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { + return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts) +} + +// Setxattr implements vfs.FileDescriptionImpl.Setxattr. +func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { + return fd.inode().setxattr(auth.CredentialsFromContext(ctx), &opts) +} + +// Removexattr implements vfs.FileDescriptionImpl.Removexattr. +func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { + return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name) +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index bb7e3cbc3..e0ff58d8c 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -169,6 +169,7 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostcpu", "//pkg/sentry/inet", diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index d09d97825..ed40b5303 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -307,6 +307,61 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return fds, nil } +// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available +// greater than or equal to the fd parameter. All files will share the set +// flags. Success is guaranteed to be all or none. +func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) { + if fd < 0 { + // Don't accept negative FDs. + return nil, syscall.EINVAL + } + + // Default limit. + end := int32(math.MaxInt32) + + // Ensure we don't get past the provided limit. + if limitSet := limits.FromContext(ctx); limitSet != nil { + lim := limitSet.Get(limits.NumberOfFiles) + if lim.Cur != limits.Infinity { + end = int32(lim.Cur) + } + if fd >= end { + return nil, syscall.EMFILE + } + } + + f.mu.Lock() + defer f.mu.Unlock() + + // From f.next to find available fd. + if fd < f.next { + fd = f.next + } + + // Install all entries. + for i := fd; i < end && len(fds) < len(files); i++ { + if d, _, _ := f.getVFS2(i); d == nil { + f.setVFS2(i, files[len(fds)], flags) // Set the descriptor. + fds = append(fds, i) // Record the file descriptor. + } + } + + // Failure? Unwind existing FDs. + if len(fds) < len(files) { + for _, i := range fds { + f.setVFS2(i, nil, FDFlags{}) // Zap entry. + } + return nil, syscall.EMFILE + } + + if fd == f.next { + // Update next search start position. + f.next = fds[len(fds)-1] + 1 + } + + return fds, nil +} + // NewFDVFS2 allocates a file descriptor greater than or equal to minfd for // the given file description. If it succeeds, it takes a reference on file. func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0a448b57c..de8a95854 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -564,15 +564,25 @@ func (ts *TaskSet) unregisterEpollWaiters() { ts.mu.RLock() defer ts.mu.RUnlock() + + // Tasks that belong to the same process could potentially point to the + // same FDTable. So we retain a map of processed ones to avoid + // processing the same FDTable multiple times. + processed := make(map[*FDTable]struct{}) for t := range ts.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if e, ok := file.FileOperations.(*epoll.EventPoll); ok { - e.UnregisterEpollWaiters() - } - }) + if t.fdTable == nil { + continue + } + if _, ok := processed[t.fdTable]; ok { + continue } + t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if e, ok := file.FileOperations.(*epoll.EventPoll); ok { + e.UnregisterEpollWaiters() + } + }) + processed[t.fdTable] = struct{}{} } } @@ -1034,14 +1044,17 @@ func (k *Kernel) pauseTimeLocked() { // This means we'll iterate FDTables shared by multiple tasks repeatedly, // but ktime.Timer.Pause is idempotent so this is harmless. if t.fdTable != nil { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if !VFS2Enabled { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { + if VFS2Enabled { + if tfd, ok := fd.Impl().(*vfs.TimerFileDescription); ok { + tfd.PauseTimer() + } + } else { if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok { tfd.PauseTimer() } - }) - } + } + }) } } k.timekeeper.PauseUpdates() @@ -1066,15 +1079,18 @@ func (k *Kernel) resumeTimeLocked() { it.ResumeTimer() } } - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if !VFS2Enabled { - if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if t.fdTable != nil { + t.fdTable.forEach(func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) { + if VFS2Enabled { + if tfd, ok := fd.Impl().(*vfs.TimerFileDescription); ok { + tfd.ResumeTimer() + } + } else { if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok { tfd.ResumeTimer() } - }) - } + } + }) } } } @@ -1435,9 +1451,10 @@ func (k *Kernel) SupervisorContext() context.Context { // +stateify savable type SocketEntry struct { socketEntry - k *Kernel - Sock *refs.WeakRef - ID uint64 // Socket table entry number. + k *Kernel + Sock *refs.WeakRef + SockVFS2 *vfs.FileDescription + ID uint64 // Socket table entry number. } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. @@ -1460,7 +1477,30 @@ func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Unlock() } +// RecordSocketVFS2 adds a VFS2 socket to the system-wide socket table for +// tracking. +// +// Precondition: Caller must hold a reference to sock. +// +// Note that the socket table will not hold a reference on the +// vfs.FileDescription, because we do not support weak refs on VFS2 files. +func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { + k.extMu.Lock() + id := k.nextSocketEntry + k.nextSocketEntry++ + s := &SocketEntry{ + k: k, + ID: id, + SockVFS2: sock, + } + k.sockets.PushBack(s) + k.extMu.Unlock() +} + // ListSockets returns a snapshot of all sockets. +// +// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef() +// to get a reference on a socket in the table. func (k *Kernel) ListSockets() []*SocketEntry { k.extMu.Lock() var socks []*SocketEntry diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 725e9db7d..62c8691f1 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -255,7 +255,8 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. wanted := ops.left() - if avail := p.max - p.view.Size(); wanted > avail { + avail := p.max - p.view.Size() + if wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } @@ -268,8 +269,14 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { return done, err } - if wanted > done { - // Partial write due to full pipe. + if done < avail { + // Non-failure, but short write. + return done, nil + } + if done < wanted { + // Partial write due to full pipe. Note that this could also be + // the short write case above, we would expect a second call + // and the write to return zero bytes in this case. return done, syserror.ErrWouldBlock } diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 35ad97d5d..e23e796ef 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -184,7 +184,6 @@ func (t *Task) CanTrace(target *Task, attach bool) bool { if targetCreds.PermittedCaps&^callerCreds.PermittedCaps != 0 { return false } - // TODO: Yama LSM return true } diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 208569057..f66cfcc7f 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -461,7 +461,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.A func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) { s.mu.Lock() defer s.mu.Unlock() - // TODO(b/38173783): RemoveMapping may be called during task exit, when ctx + // RemoveMapping may be called during task exit, when ctx // is context.Background. Gracefully handle missing clocks. Failing to // update the detach time in these cases is ok, since no one can observe the // omission. diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 93c4fe969..2e3565747 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -209,65 +209,61 @@ type Stracer interface { // SyscallEnter is called on syscall entry. // // The returned private data is passed to SyscallExit. - // - // TODO(gvisor.dev/issue/155): remove kernel imports from the strace - // package so that the type can be used directly. SyscallEnter(t *Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{} // SyscallExit is called on syscall exit. SyscallExit(context interface{}, t *Task, sysno, rval uintptr, err error) } -// SyscallTable is a lookup table of system calls. Critically, a SyscallTable -// is *immutable*. In order to make supporting suspend and resume sane, they -// must be uniquely registered and may not change during operation. +// SyscallTable is a lookup table of system calls. // -// +stateify savable +// Note that a SyscallTable is not savable directly. Instead, they are saved as +// an OS/Arch pair and lookup happens again on restore. type SyscallTable struct { // OS is the operating system that this syscall table implements. - OS abi.OS `state:"wait"` + OS abi.OS // Arch is the architecture that this syscall table targets. - Arch arch.Arch `state:"wait"` + Arch arch.Arch // The OS version that this syscall table implements. - Version Version `state:"manual"` + Version Version // AuditNumber is a numeric constant that represents the syscall table. If // non-zero, auditNumber must be one of the AUDIT_ARCH_* values defined by // linux/audit.h. - AuditNumber uint32 `state:"manual"` + AuditNumber uint32 // Table is the collection of functions. - Table map[uintptr]Syscall `state:"manual"` + Table map[uintptr]Syscall // lookup is a fixed-size array that holds the syscalls (indexed by // their numbers). It is used for fast look ups. - lookup []SyscallFn `state:"manual"` + lookup []SyscallFn // Emulate is a collection of instruction addresses to emulate. The // keys are addresses, and the values are system call numbers. - Emulate map[usermem.Addr]uintptr `state:"manual"` + Emulate map[usermem.Addr]uintptr // The function to call in case of a missing system call. - Missing MissingFn `state:"manual"` + Missing MissingFn // Stracer traces this syscall table. - Stracer Stracer `state:"manual"` + Stracer Stracer // External is used to handle an external callback. - External func(*Kernel) `state:"manual"` + External func(*Kernel) // ExternalFilterBefore is called before External is called before the syscall is executed. // External is not called if it returns false. - ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"` + ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool // ExternalFilterAfter is called before External is called after the syscall is executed. // External is not called if it returns false. - ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"` + ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool // FeatureEnable stores the strace and one-shot enable bits. - FeatureEnable SyscallFlagsTable `state:"manual"` + FeatureEnable SyscallFlagsTable } // allSyscallTables contains all known tables. diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go index 00358326b..90f890495 100644 --- a/pkg/sentry/kernel/syscalls_state.go +++ b/pkg/sentry/kernel/syscalls_state.go @@ -14,16 +14,34 @@ package kernel -import "fmt" +import ( + "fmt" -// afterLoad is invoked by stateify. -func (s *SyscallTable) afterLoad() { - otherTable, ok := LookupSyscallTable(s.OS, s.Arch) - if !ok { - // Couldn't find a reference? - panic(fmt.Sprintf("syscall table not found for OS %v Arch %v", s.OS, s.Arch)) + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/sentry/arch" +) + +// syscallTableInfo is used to reload the SyscallTable. +// +// +stateify savable +type syscallTableInfo struct { + OS abi.OS + Arch arch.Arch +} + +// saveSt saves the SyscallTable. +func (tc *TaskContext) saveSt() syscallTableInfo { + return syscallTableInfo{ + OS: tc.st.OS, + Arch: tc.st.Arch, } +} - // Copy the table. - *s = *otherTable +// loadSt loads the SyscallTable. +func (tc *TaskContext) loadSt(sti syscallTableInfo) { + st, ok := LookupSyscallTable(sti.OS, sti.Arch) + if !ok { + panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch)) + } + tc.st = st // Save the table reference. } diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index d6546735e..e5d133d6c 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -777,6 +777,15 @@ func (t *Task) NewFDs(fd int32, files []*fs.File, flags FDFlags) ([]int32, error return t.fdTable.NewFDs(t, fd, files, flags) } +// NewFDsVFS2 is a convenience wrapper for t.FDTable().NewFDsVFS2. +// +// This automatically passes the task as the context. +// +// Precondition: same as FDTable. +func (t *Task) NewFDsVFS2(fd int32, files []*vfs.FileDescription, flags FDFlags) ([]int32, error) { + return t.fdTable.NewFDsVFS2(t, fd, files, flags) +} + // NewFDFrom is a convenience wrapper for t.FDTable().NewFDs with a single file. // // This automatically passes the task as the context. diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 0158b1788..9fa528384 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -49,7 +49,7 @@ type TaskContext struct { fu *futex.Manager // st is the task's syscall table. - st *SyscallTable + st *SyscallTable `state:".(syscallTableInfo)"` } // release releases all resources held by the TaskContext. release is called by @@ -58,7 +58,6 @@ func (tc *TaskContext) release() { // Nil out pointers so that if the task is saved after release, it doesn't // follow the pointers to possibly now-invalid objects. if tc.MemoryManager != nil { - // TODO(b/38173783) tc.MemoryManager.DecUsers(context.Background()) tc.MemoryManager = nil } diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index ce3e6ef28..0325967e4 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -455,7 +455,7 @@ func (t *Task) SetKeepCaps(k bool) { t.creds.Store(creds) } -// updateCredsForExec updates t.creds to reflect an execve(). +// updateCredsForExecLocked updates t.creds to reflect an execve(). // // NOTE(b/30815691): We currently do not implement privileged executables // (set-user/group-ID bits and file capabilities). This allows us to make a lot diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 8802db142..6aa798346 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -513,8 +513,6 @@ func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool { if t.stop != nil { return false } - // - TODO(b/38173783): No special case for when t is also the sending task, - // because the identity of the sender is unknown. // - Do not choose tasks that have already been interrupted, as they may be // busy handling another signal. if len(t.interruptChan) != 0 { diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index 706de83ef..e959700f2 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -245,7 +245,7 @@ type Clock interface { type WallRateClock struct{} // WallTimeUntil implements Clock.WallTimeUntil. -func (WallRateClock) WallTimeUntil(t, now Time) time.Duration { +func (*WallRateClock) WallTimeUntil(t, now Time) time.Duration { return t.Sub(now) } @@ -254,16 +254,16 @@ func (WallRateClock) WallTimeUntil(t, now Time) time.Duration { type NoClockEvents struct{} // Readiness implements waiter.Waitable.Readiness. -func (NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask { +func (*NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask { return 0 } // EventRegister implements waiter.Waitable.EventRegister. -func (NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (*NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) { } // EventUnregister implements waiter.Waitable.EventUnregister. -func (NoClockEvents) EventUnregister(e *waiter.Entry) { +func (*NoClockEvents) EventUnregister(e *waiter.Entry) { } // ClockEventsQueue implements waiter.Waitable by wrapping waiter.Queue and @@ -273,7 +273,7 @@ type ClockEventsQueue struct { } // Readiness implements waiter.Waitable.Readiness. -func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask { +func (*ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask { return 0 } diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index 0332fc71c..5c667117c 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -201,8 +201,10 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, pre if pma.needCOW { perms.Write = false } - if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil { - return err + if perms.Any() { // MapFile precondition + if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil { + return err + } } pseg = pseg.NextSegment() } diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index cb29d94b0..379148903 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -59,25 +59,27 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool { } a.contexts[id] = &AIOContext{ - done: make(chan struct{}, 1), + requestReady: make(chan struct{}, 1), maxOutstanding: events, } return true } -// destroyAIOContext destroys an asynchronous I/O context. +// destroyAIOContext destroys an asynchronous I/O context. It doesn't wait for +// for pending requests to complete. Returns the destroyed AIOContext so it can +// be drained. // -// False is returned if the context does not exist. -func (a *aioManager) destroyAIOContext(id uint64) bool { +// Nil is returned if the context does not exist. +func (a *aioManager) destroyAIOContext(id uint64) *AIOContext { a.mu.Lock() defer a.mu.Unlock() ctx, ok := a.contexts[id] if !ok { - return false + return nil } delete(a.contexts, id) ctx.destroy() - return true + return ctx } // lookupAIOContext looks up the given context. @@ -102,8 +104,8 @@ type ioResult struct { // // +stateify savable type AIOContext struct { - // done is the notification channel used for all requests. - done chan struct{} `state:"nosave"` + // requestReady is the notification channel used for all requests. + requestReady chan struct{} `state:"nosave"` // mu protects below. mu sync.Mutex `state:"nosave"` @@ -129,8 +131,14 @@ func (ctx *AIOContext) destroy() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.dead = true - if ctx.outstanding == 0 { - close(ctx.done) + ctx.checkForDone() +} + +// Preconditions: ctx.mu must be held by caller. +func (ctx *AIOContext) checkForDone() { + if ctx.dead && ctx.outstanding == 0 { + close(ctx.requestReady) + ctx.requestReady = nil } } @@ -154,11 +162,12 @@ func (ctx *AIOContext) PopRequest() (interface{}, bool) { // Is there anything ready? if e := ctx.results.Front(); e != nil { - ctx.results.Remove(e) - ctx.outstanding-- - if ctx.outstanding == 0 && ctx.dead { - close(ctx.done) + if ctx.outstanding == 0 { + panic("AIOContext outstanding is going negative") } + ctx.outstanding-- + ctx.results.Remove(e) + ctx.checkForDone() return e.data, true } return nil, false @@ -172,26 +181,58 @@ func (ctx *AIOContext) FinishRequest(data interface{}) { // Push to the list and notify opportunistically. The channel notify // here is guaranteed to be safe because outstanding must be non-zero. - // The done channel is only closed when outstanding reaches zero. + // The requestReady channel is only closed when outstanding reaches zero. ctx.results.PushBack(&ioResult{data: data}) select { - case ctx.done <- struct{}{}: + case ctx.requestReady <- struct{}{}: default: } } // WaitChannel returns a channel that is notified when an AIO request is -// completed. -// -// The boolean return value indicates whether or not the context is active. -func (ctx *AIOContext) WaitChannel() (chan struct{}, bool) { +// completed. Returns nil if the context is destroyed and there are no more +// outstanding requests. +func (ctx *AIOContext) WaitChannel() chan struct{} { ctx.mu.Lock() defer ctx.mu.Unlock() - if ctx.outstanding == 0 && ctx.dead { - return nil, false + return ctx.requestReady +} + +// Dead returns true if the context has been destroyed. +func (ctx *AIOContext) Dead() bool { + ctx.mu.Lock() + defer ctx.mu.Unlock() + return ctx.dead +} + +// CancelPendingRequest forgets about a request that hasn't yet completed. +func (ctx *AIOContext) CancelPendingRequest() { + ctx.mu.Lock() + defer ctx.mu.Unlock() + + if ctx.outstanding == 0 { + panic("AIOContext outstanding is going negative") } - return ctx.done, true + ctx.outstanding-- + ctx.checkForDone() +} + +// Drain drops all completed requests. Pending requests remain untouched. +func (ctx *AIOContext) Drain() { + ctx.mu.Lock() + defer ctx.mu.Unlock() + + if ctx.outstanding == 0 { + return + } + size := uint32(ctx.results.Len()) + if ctx.outstanding < size { + panic("AIOContext outstanding is going negative") + } + ctx.outstanding -= size + ctx.results.Reset() + ctx.checkForDone() } // aioMappable implements memmap.MappingIdentity and memmap.Mappable for AIO @@ -332,9 +373,9 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint Length: aioRingBufferSize, MappingIdentity: m, Mappable: m, - // TODO(fvoznika): Linux does "do_mmap_pgoff(..., PROT_READ | - // PROT_WRITE, ...)" in fs/aio.c:aio_setup_ring(); why do we make this - // mapping read-only? + // Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in + // fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC, + // user mode should not write to this page. Perms: usermem.Read, MaxPerms: usermem.Read, }) @@ -349,11 +390,11 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint return id, nil } -// DestroyAIOContext destroys an asynchronous I/O context. It returns false if -// the context does not exist. -func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) bool { +// DestroyAIOContext destroys an asynchronous I/O context. It returns the +// destroyed context. nil if the context does not exist. +func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext { if _, ok := mm.LookupAIOContext(ctx, id); !ok { - return false + return nil } // Only unmaps after it assured that the address is a valid aio context to diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go index c37fc9f7b..3dabac1af 100644 --- a/pkg/sentry/mm/aio_context_state.go +++ b/pkg/sentry/mm/aio_context_state.go @@ -16,5 +16,5 @@ package mm // afterLoad is invoked by stateify. func (a *AIOContext) afterLoad() { - a.done = make(chan struct{}, 1) + a.requestReady = make(chan struct{}, 1) } diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index cdcad0fdc..b69520030 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -71,8 +71,8 @@ go_library( "lib_amd64.go", "lib_amd64.s", "lib_arm64.go", - "lib_arm64_unsafe.go", "lib_arm64.s", + "lib_arm64_unsafe.go", "ring0.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f14c336b9..7ac38764d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -300,7 +300,7 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { return nil, syserr.TranslateNetstackError(err) } } @@ -535,7 +535,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } if resCh != nil { - t := ctx.(*kernel.Task) + t := kernel.TaskFromContext(ctx) if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err).ToError() } @@ -608,7 +608,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } if resCh != nil { - t := ctx.(*kernel.Task) + t := kernel.TaskFromContext(ctx) if err := t.Block(resCh); err != nil { return 0, syserr.FromError(err).ToError() } @@ -663,7 +663,7 @@ func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error // This is a hack to work around the fact that both IPv4 and IPv6 ANY are // represented by the empty string. // -// TODO(gvisor.dev/issues/1556): remove this function. +// TODO(gvisor.dev/issue/1556): remove this function. func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET { addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" @@ -940,7 +940,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -965,8 +965,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return nil, syserr.ErrProtocolNotAvailable } +func boolToInt32(v bool) int32 { + if v { + return 1 + } + return 0 +} + // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -998,12 +1005,11 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var v tcpip.PasscredOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.PasscredOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1042,24 +1048,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var v tcpip.ReuseAddressOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReusePortOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReusePortOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1089,24 +1093,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var v tcpip.BroadcastOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.BroadcastOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.KeepaliveEnabledOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { @@ -1156,47 +1158,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptInt(tcpip.DelayOption) + v, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v == 0 { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(!v), nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.CorkOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.CorkOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.QuickAckOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.QuickAckOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MaxSegOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MaxSegOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1328,11 +1324,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1342,8 +1334,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if outLen == 0 { return make([]byte, 0), nil } - var v tcpip.IPv6TrafficClassOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1365,12 +1357,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIPv6(t, name) @@ -1386,8 +1373,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.TTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.TTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1403,8 +1390,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastTTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1429,23 +1416,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastLoopOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(v), nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { return []byte(nil), nil } - var v tcpip.IPv4TOSOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { @@ -1462,11 +1445,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1477,11 +1456,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIP(t, name) @@ -1541,7 +1516,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa // SetSockOpt can be used to implement the linux syscall setsockopt(2) for // sockets backed by a commonEndpoint. -func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { +func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { switch level { case linux.SOL_SOCKET: return setSockOptSocket(t, s, ep, name, optVal) @@ -1568,7 +1543,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n } // setSockOptSocket implements SetSockOpt when level is SOL_SOCKET. -func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.SO_SNDBUF: if len(optVal) < sizeOfInt32 { @@ -1592,7 +1567,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1600,7 +1575,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1628,7 +1603,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1636,7 +1611,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1644,7 +1619,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1716,11 +1691,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o int - if v == 0 { - o = 1 - } - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1728,7 +1699,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -1736,7 +1707,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -1744,7 +1715,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v))) case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { @@ -1855,7 +1826,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) if v == -1 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v))) case linux.IPV6_RECVTCLASS: v, err := parseIntOrChar(optVal) @@ -1940,7 +1911,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if v < 0 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v))) case linux.IP_ADD_MEMBERSHIP: req, err := copyInMulticastRequest(optVal, false /* allowAddr */) @@ -1987,9 +1958,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.MulticastLoopOption(v != 0), - )) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2008,7 +1977,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } else if v < 1 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v))) case linux.IP_TOS: if len(optVal) == 0 { @@ -2018,7 +1987,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v))) case linux.IP_RECVTOS: v, err := parseIntOrChar(optVal) diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index eb090e79b..c3f04b613 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -62,10 +62,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in } case linux.SOCK_RAW: - // TODO(b/142504697): "In order to create a raw socket, a - // process must have the CAP_NET_RAW capability in the user - // namespace that governs its network namespace." - raw(7) - // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { @@ -141,10 +137,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* } func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { - // TODO(b/142504697): "In order to create a packet socket, a process - // must have the CAP_NET_RAW capability in the user namespace that - // governs its network namespace." - packet(7) - // Packet sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(t) if !creds.HasCapability(linux.CAP_NET_RAW) { diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index b5ba4a56b..6580bd6e9 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -269,7 +269,7 @@ func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*v return nil, err } if s != nil { - // TODO: Add vfs2 sockets to global socket table. + t.Kernel().RecordSocketVFS2(s) return s, nil } } @@ -291,7 +291,9 @@ func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (* return nil, nil, err } if s1 != nil && s2 != nil { - // TODO: Add vfs2 sockets to global socket table. + k := t.Kernel() + k.RecordSocketVFS2(s1) + k.RecordSocketVFS2(s2) return s1, s2, nil } } diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 08743deba..de2cc4bdf 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -8,23 +8,27 @@ go_library( "device.go", "io.go", "unix.go", + "unix_vfs2.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/fspath", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 74bcd6300..c708b6030 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/ilist", + "//pkg/log", "//pkg/refs", "//pkg/sync", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2ef654235..1f3880cc5 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" @@ -838,24 +839,45 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. Currently not supported. func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.PasscredOption: - e.setPasscred(v != 0) - return nil - } return nil } func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: + case tcpip.PasscredOption: + e.setPasscred(v) + case tcpip.ReuseAddressOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + return tcpip.ErrUnknownProtocolOption + } return nil } func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + return tcpip.ErrUnknownProtocolOption + } return nil } func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.PasscredOption: + return e.Passcred(), nil + + default: + log.Warningf("Unsupported socket option: %d", opt) + return false, tcpip.ErrUnknownProtocolOption + } } func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { @@ -914,29 +936,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return int(v), nil default: + log.Warningf("Unsupported socket option: %d", opt) return -1, tcpip.ErrUnknownProtocolOption } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) - } - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: + log.Warningf("Unsupported socket option: %T", opt) return tcpip.ErrUnknownProtocolOption } } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4d30aa714..7c64f30fa 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -33,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -52,11 +54,8 @@ type SocketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` - refs.AtomicRefCount - socket.SendReceiveTimeout - ep transport.Endpoint - stype linux.SockType + socketOpsCommon } // New creates a new unix socket. @@ -75,16 +74,29 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty } s := SocketOperations{ - ep: ep, - stype: stype, + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, } s.EnableLeakCheck("unix.SocketOperations") return fs.NewFile(ctx, d, flags, &s) } +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { + refs.AtomicRefCount + socket.SendReceiveTimeout + + ep transport.Endpoint + stype linux.SockType +} + // DecRef implements RefCounter.DecRef. -func (s *SocketOperations) DecRef() { +func (s *socketOpsCommon) DecRef() { s.DecRefWithDestructor(func() { s.ep.Close() }) @@ -97,7 +109,7 @@ func (s *SocketOperations) Release() { s.DecRef() } -func (s *SocketOperations) isPacket() bool { +func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: return true @@ -110,7 +122,7 @@ func (s *SocketOperations) isPacket() bool { } // Endpoint extracts the transport.Endpoint. -func (s *SocketOperations) Endpoint() transport.Endpoint { +func (s *socketOpsCommon) Endpoint() transport.Endpoint { return s.ep } @@ -143,7 +155,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -155,7 +167,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -178,7 +190,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // Listen implements the linux syscall listen(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return s.ep.Listen(backlog) } @@ -310,6 +322,8 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Create the socket. + // + // TODO(gvisor.dev/issue/2324): Correctly set file permissions. childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}}) if err != nil { return syserr.ErrPortInUse @@ -345,6 +359,31 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, return ep, nil } + if kernel.VFS2Enabled { + p := fspath.Parse(path) + root := t.FSContext().RootDirectoryVFS2() + start := root + relPath := !p.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: p, + FollowFinalSymlink: true, + } + ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop) + root.DecRef() + if relPath { + start.DecRef() + } + if e != nil { + return nil, syserr.FromError(e) + } + return ep, nil + } + // Find the node in the filesystem. root := t.FSContext().RootDirectory() cwd := t.FSContext().WorkingDirectory() @@ -363,12 +402,11 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, // No socket! return nil, syserr.ErrConnectionRefused } - return ep, nil } // Connect implements the linux syscall connect(2) for unix sockets. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { ep, err := extractEndpoint(t, sockaddr) if err != nil { return err @@ -379,7 +417,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo return s.ep.Connect(t, ep) } -// Writev implements fs.FileOperations.Write. +// Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { t := kernel.TaskFromContext(ctx) ctrl := control.New(t, s.ep, nil) @@ -399,7 +437,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO // SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by // a transport.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ Ctx: t, Endpoint: s.ep, @@ -453,27 +491,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } // Passcred implements transport.Credentialer.Passcred. -func (s *SocketOperations) Passcred() bool { +func (s *socketOpsCommon) Passcred() bool { return s.ep.Passcred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. -func (s *SocketOperations) ConnectedPasscred() bool { +func (s *socketOpsCommon) ConnectedPasscred() bool { return s.ep.ConnectedPasscred() } // Readiness implements waiter.Waitable.Readiness. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.ep.Readiness(mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.ep.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *SocketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } @@ -485,7 +523,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := netstack.ConvertShutdown(how) if err != nil { return err @@ -511,7 +549,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -648,12 +686,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } // State implements socket.Socket.State. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { return s.ep.State() } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { // Unix domain sockets always have a protocol of 0. return linux.AF_UNIX, s.stype, 0 } @@ -706,4 +744,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F func init() { socket.RegisterProvider(linux.AF_UNIX, &provider{}) + socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{}) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go new file mode 100644 index 000000000..3e54d49c4 --- /dev/null +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -0,0 +1,348 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/control" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SocketVFS2 implements socket.SocketVFS2 (and by extension, +// vfs.FileDescriptionImpl) for Unix sockets. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + + socketOpsCommon +} + +// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket. +func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { + sock := NewFDImpl(ep, stype) + vfsfd := &sock.vfsfd + if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil { + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// NewFDImpl creates and returns a new SocketVFS2. +func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 { + // You can create AF_UNIX, SOCK_RAW sockets. They're the same as + // SOCK_DGRAM and don't require CAP_NET_RAW. + if stype == linux.SOCK_RAW { + stype = linux.SOCK_DGRAM + } + + return &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, + } +} + +// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) +} + +// blockingAccept implements a blocking version of accept(2), that is, if no +// connections are ready to be accept, it will block until one becomes ready. +func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { + // Register for notifications. + e, ch := waiter.NewChannelEntry(nil) + s.socketOpsCommon.EventRegister(&e, waiter.EventIn) + defer s.socketOpsCommon.EventUnregister(&e) + + // Try to accept the connection; if it fails, then wait until we get a + // notification. + for { + if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + return ep, err + } + + if err := t.Block(ch); err != nil { + return nil, syserr.FromError(err) + } + } +} + +// Accept implements the linux syscall accept(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + // Issue the accept request to get the new endpoint. + ep, err := s.ep.Accept() + if err != nil { + if err != syserr.ErrWouldBlock || !blocking { + return 0, nil, 0, err + } + + var err *syserr.Error + ep, err = s.blockingAccept(t) + if err != nil { + return 0, nil, 0, err + } + } + + // We expect this to be a FileDescription here. + ns, err := NewVFS2File(t, ep, s.stype) + if err != nil { + return 0, nil, 0, err + } + defer ns.DecRef() + + if flags&linux.SOCK_NONBLOCK != 0 { + ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK) + } + + var addr linux.SockAddr + var addrLen uint32 + if peerRequested { + // Get address of the peer. + var err *syserr.Error + addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) + if err != nil { + return 0, nil, 0, err + } + } + + fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ + CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, + }) + if e != nil { + return 0, nil, 0, syserr.FromError(e) + } + + t.Kernel().RecordSocketVFS2(ns) + return fd, addr, addrLen, nil +} + +// Bind implements the linux syscall bind(2) for unix sockets. +func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + p, e := extractPath(sockaddr) + if e != nil { + return e + } + + bep, ok := s.ep.(transport.BoundEndpoint) + if !ok { + // This socket can't be bound. + return syserr.ErrInvalidArgument + } + + return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error { + // Is it abstract? + if p[0] == 0 { + if t.IsNetworkNamespaced() { + return syserr.ErrInvalidEndpointState + } + if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + // syserr.ErrPortInUse corresponds to EADDRINUSE. + return syserr.ErrPortInUse + } + } else { + path := fspath.Parse(p) + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef() + start := root + relPath := !path.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef() + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + } + err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{ + // TODO(gvisor.dev/issue/2324): The file permissions should be taken + // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind), + // but VFS1 just always uses 0400. Resolve this inconsistency. + Mode: linux.S_IFSOCK | 0400, + Endpoint: bep, + }) + if err == syserror.EEXIST { + return syserr.ErrAddressInUse + } + return syserr.FromError(err) + } + + return nil + }) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return netstack.Ioctl(ctx, s.ep, uio, args) +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + return dst.CopyOutFrom(ctx, &EndpointReader{ + Ctx: ctx, + Endpoint: s.ep, + NumRights: 0, + Peek: false, + From: nil, + }) +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + t := kernel.TaskFromContext(ctx) + ctrl := control.New(t, s.ep, nil) + + if src.NumBytes() == 0 { + nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil) + return int64(nInt), err.ToError() + } + + return src.CopyInTo(ctx, &EndpointWriter{ + Ctx: ctx, + Endpoint: s.ep, + Control: ctrl, + To: nil, + }) +} + +// Release implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Release() { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef() +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) +} + +// providerVFS2 is a unix domain socket provider for VFS2. +type providerVFS2 struct{} + +func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, syserr.ErrProtocolNotSupported + } + + // Create the endpoint and socket. + var ep transport.Endpoint + switch stype { + case linux.SOCK_DGRAM, linux.SOCK_RAW: + ep = transport.NewConnectionless(t) + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: + ep = transport.NewConnectioned(t, stype, t.Kernel()) + default: + return nil, syserr.ErrInvalidArgument + } + + f, err := NewVFS2File(t, ep, stype) + if err != nil { + ep.Close() + return nil, err + } + return f, nil +} + +// Pair creates a new pair of AF_UNIX connected sockets. +func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, nil, syserr.ErrProtocolNotSupported + } + + switch stype { + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW: + // Ok + default: + return nil, nil, syserr.ErrInvalidArgument + } + + // Create the endpoints and sockets. + ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) + s1, err := NewVFS2File(t, ep1, stype) + if err != nil { + ep1.Close() + ep2.Close() + return nil, nil, err + } + s2, err := NewVFS2File(t, ep2, stype) + if err != nil { + s1.DecRef() + ep2.Close() + return nil, nil, err + } + + return s1, s2, nil +} diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 77655558e..b94c4fbf5 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -778,9 +778,6 @@ func (s SyscallMap) Name(sysno uintptr) string { // // N.B. This is not in an init function because we can't be sure all syscall // tables are registered with the kernel when init runs. -// -// TODO(gvisor.dev/issue/155): remove kernel package dependencies from this -// package and have the kernel package self-initialize all syscall tables. func Initialize() { for _, table := range kernel.SyscallTables() { // Is this known? diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index b401978db..38cbeba5a 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -114,14 +114,28 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { id := args[0].Uint64() - // Destroy the given context. - if !t.MemoryManager().DestroyAIOContext(t, id) { + ctx := t.MemoryManager().DestroyAIOContext(t, id) + if ctx == nil { // Does not exist. return 0, nil, syserror.EINVAL } - // FIXME(fvoznika): Linux blocks until all AIO to the destroyed context is - // done. - return 0, nil, nil + + // Drain completed requests amd wait for pending requests until there are no + // more. + for { + ctx.Drain() + + ch := ctx.WaitChannel() + if ch == nil { + // No more requests, we're done. + return 0, nil, nil + } + // The task cannot be interrupted during the wait. Equivalent to + // TASK_UNINTERRUPTIBLE in Linux. + t.UninterruptibleSleepStart(true /* deactivate */) + <-ch + t.UninterruptibleSleepFinish(true /* activate */) + } } // IoGetevents implements linux syscall io_getevents(2). @@ -200,13 +214,13 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadline ktime.Time) (interface{}, error) { for { if v, ok := ctx.PopRequest(); ok { - // Request was readly available. Just return it. + // Request was readily available. Just return it. return v, nil } // Need to wait for request completion. - done, active := ctx.WaitChannel() - if !active { + done := ctx.WaitChannel() + if done == nil { // Context has been destroyed. return nil, syserror.EINVAL } @@ -248,6 +262,10 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) { } func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) { + if ctx.Dead() { + ctx.CancelPendingRequest() + return + } ev := &ioEvent{ Data: cb.Data, Obj: uint64(cbAddr), diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 9c6728530..f92bf8096 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -161,8 +161,8 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 { return 0, nil, syserror.EINVAL } - // no_new_privs is assumed to always be set. See - // kernel.Task.updateCredsForExec. + // PR_SET_NO_NEW_PRIVS is assumed to always be set. + // See kernel.Task.updateCredsForExecLocked. return 0, nil, nil case linux.PR_GET_NO_NEW_PRIVS: diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 2919228d0..61b2576ac 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -31,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// LINT.IfChange + // minListenBacklog is the minimum reasonable backlog for listening sockets. const minListenBacklog = 8 @@ -244,7 +246,10 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Copy the file descriptors out. if _, err := t.CopyOut(socks, fds); err != nil { - // Note that we don't close files here; see pipe(2) also. + for _, fd := range fds { + _, file := t.FDTable().Remove(fd) + file.DecRef() + } return 0, nil, err } @@ -1128,3 +1133,5 @@ func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen) return n, nil, err } + +// LINT.ThenChange(./vfs2/socket.go) diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index fd642834b..fbc6cf15f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,6 +29,10 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB return 0, syserror.EINVAL } + if opts.Length > int64(kernel.MAX_RW_COUNT) { + opts.Length = int64(kernel.MAX_RW_COUNT) + } + var ( total int64 n int64 diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 2eb210014..b32abfe59 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -21,16 +21,19 @@ go_library( "poll.go", "read_write.go", "setstat.go", + "socket.go", "stat.go", "stat_amd64.go", "stat_arm64.go", "sync.go", + "sys_timerfd.go", "xattr.go", ], marshal = True, visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/binary", "//pkg/bits", "//pkg/fspath", "//pkg/gohacks", @@ -42,10 +45,14 @@ go_library( "//pkg/sentry/limits", "//pkg/sentry/loader", "//pkg/sentry/memmap", + "//pkg/sentry/socket", + "//pkg/sentry/socket/control", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/syscalls", "//pkg/sentry/syscalls/linux", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/syserr", "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go index a61cc5059..62e98817d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/getdents.go +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -97,6 +97,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { // char d_name[]; /* Filename (null-terminated) */ // }; size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name) + size = (size + 7) &^ 7 // round up to multiple of 8 if size > cb.remaining { return syserror.EINVAL } @@ -106,7 +107,12 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) buf[18] = dirent.Type copy(buf[19:], dirent.Name) - buf[size-1] = 0 // NUL terminator + // Zero out all remaining bytes in buf, including the NUL terminator + // after dirent.Name. + bufTail := buf[19+len(dirent.Name):] + for i := range bufTail { + bufTail[i] = 0 + } } else { // struct linux_dirent { // unsigned long d_ino; /* Inode number */ @@ -125,6 +131,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width())) } size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name) + size = (size + 7) &^ 7 // round up to multiple of sizeof(long) if size > cb.remaining { return syserror.EINVAL } @@ -133,9 +140,14 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff)) usermem.ByteOrder.PutUint16(buf[16:18], uint16(size)) copy(buf[18:], dirent.Name) - buf[size-3] = 0 // NUL terminator - buf[size-2] = 0 // zero padding byte - buf[size-1] = dirent.Type + // Zero out all remaining bytes in buf, including the NUL terminator + // after dirent.Name and the zero padding byte between the name and + // dirent type. + bufTail := buf[18+len(dirent.Name):] + for i := range bufTail { + bufTail[i] = 0 + } + bufTail[2] = dirent.Type } n, err := cb.t.CopyOutBytes(cb.addr, buf) if err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go index 7d220bc20..645e0bcb8 100644 --- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go +++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go @@ -44,21 +44,22 @@ func Override(table map[uintptr]kernel.Syscall) { table[32] = syscalls.Supported("dup", Dup) table[33] = syscalls.Supported("dup2", Dup2) delete(table, 40) // sendfile - delete(table, 41) // socket - delete(table, 42) // connect - delete(table, 43) // accept - delete(table, 44) // sendto - delete(table, 45) // recvfrom - delete(table, 46) // sendmsg - delete(table, 47) // recvmsg - delete(table, 48) // shutdown - delete(table, 49) // bind - delete(table, 50) // listen - delete(table, 51) // getsockname - delete(table, 52) // getpeername - delete(table, 53) // socketpair - delete(table, 54) // setsockopt - delete(table, 55) // getsockopt + // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. + table[41] = syscalls.PartiallySupported("socket", Socket, "In process of porting socket syscalls to VFS2.", nil) + table[42] = syscalls.PartiallySupported("connect", Connect, "In process of porting socket syscalls to VFS2.", nil) + table[43] = syscalls.PartiallySupported("accept", Accept, "In process of porting socket syscalls to VFS2.", nil) + table[44] = syscalls.PartiallySupported("sendto", SendTo, "In process of porting socket syscalls to VFS2.", nil) + table[45] = syscalls.PartiallySupported("recvfrom", RecvFrom, "In process of porting socket syscalls to VFS2.", nil) + table[46] = syscalls.PartiallySupported("sendmsg", SendMsg, "In process of porting socket syscalls to VFS2.", nil) + table[47] = syscalls.PartiallySupported("recvmsg", RecvMsg, "In process of porting socket syscalls to VFS2.", nil) + table[48] = syscalls.PartiallySupported("shutdown", Shutdown, "In process of porting socket syscalls to VFS2.", nil) + table[49] = syscalls.PartiallySupported("bind", Bind, "In process of porting socket syscalls to VFS2.", nil) + table[50] = syscalls.PartiallySupported("listen", Listen, "In process of porting socket syscalls to VFS2.", nil) + table[51] = syscalls.PartiallySupported("getsockname", GetSockName, "In process of porting socket syscalls to VFS2.", nil) + table[52] = syscalls.PartiallySupported("getpeername", GetPeerName, "In process of porting socket syscalls to VFS2.", nil) + table[53] = syscalls.PartiallySupported("socketpair", SocketPair, "In process of porting socket syscalls to VFS2.", nil) + table[54] = syscalls.PartiallySupported("getsockopt", GetSockOpt, "In process of porting socket syscalls to VFS2.", nil) + table[55] = syscalls.PartiallySupported("setsockopt", SetSockOpt, "In process of porting socket syscalls to VFS2.", nil) table[59] = syscalls.Supported("execve", Execve) table[72] = syscalls.Supported("fcntl", Fcntl) delete(table, 73) // flock @@ -139,12 +140,13 @@ func Override(table map[uintptr]kernel.Syscall) { table[280] = syscalls.Supported("utimensat", Utimensat) table[281] = syscalls.Supported("epoll_pwait", EpollPwait) delete(table, 282) // signalfd - delete(table, 283) // timerfd_create + table[283] = syscalls.Supported("timerfd_create", TimerfdCreate) delete(table, 284) // eventfd delete(table, 285) // fallocate - delete(table, 286) // timerfd_settime - delete(table, 287) // timerfd_gettime - delete(table, 288) // accept4 + table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime) + table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime) + // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. + table[288] = syscalls.PartiallySupported("accept4", Accept4, "In process of porting socket syscalls to VFS2.", nil) delete(table, 289) // signalfd4 delete(table, 290) // eventfd2 table[291] = syscalls.Supported("epoll_create1", EpollCreate1) @@ -153,9 +155,11 @@ func Override(table map[uintptr]kernel.Syscall) { delete(table, 294) // inotify_init1 table[295] = syscalls.Supported("preadv", Preadv) table[296] = syscalls.Supported("pwritev", Pwritev) - delete(table, 299) // recvmmsg + // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. + table[299] = syscalls.PartiallySupported("recvmmsg", RecvMMsg, "In process of porting socket syscalls to VFS2.", nil) table[306] = syscalls.Supported("syncfs", Syncfs) - delete(table, 307) // sendmmsg + // TODO(gvisor.dev/issue/1485): Port all socket variants to VFS2. + table[307] = syscalls.PartiallySupported("sendmmsg", SendMMsg, "In process of porting socket syscalls to VFS2.", nil) table[316] = syscalls.Supported("renameat2", Renameat2) delete(table, 319) // memfd_create table[322] = syscalls.Supported("execveat", Execveat) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go new file mode 100644 index 000000000..79a4a7ada --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -0,0 +1,1138 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/control" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// minListenBacklog is the minimum reasonable backlog for listening sockets. +const minListenBacklog = 8 + +// maxListenBacklog is the maximum allowed backlog for listening sockets. +const maxListenBacklog = 1024 + +// maxAddrLen is the maximum socket address length we're willing to accept. +const maxAddrLen = 200 + +// maxOptLen is the maximum sockopt parameter length we're willing to accept. +const maxOptLen = 1024 * 8 + +// maxControlLen is the maximum length of the msghdr.msg_control buffer we're +// willing to accept. Note that this limit is smaller than Linux, which allows +// buffers upto INT_MAX. +const maxControlLen = 10 * 1024 * 1024 + +// nameLenOffset is the offset from the start of the MessageHeader64 struct to +// the NameLen field. +const nameLenOffset = 8 + +// controlLenOffset is the offset form the start of the MessageHeader64 struct +// to the ControlLen field. +const controlLenOffset = 40 + +// flagsOffset is the offset form the start of the MessageHeader64 struct +// to the Flags field. +const flagsOffset = 48 + +const sizeOfInt32 = 4 + +// messageHeader64Len is the length of a MessageHeader64 struct. +var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) + +// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct. +var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{})) + +// baseRecvFlags are the flags that are accepted across recvmsg(2), +// recvmmsg(2), and recvfrom(2). +const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC + +// MessageHeader64 is the 64-bit representation of the msghdr struct used in +// the recvmsg and sendmsg syscalls. +type MessageHeader64 struct { + // Name is the optional pointer to a network address buffer. + Name uint64 + + // NameLen is the length of the buffer pointed to by Name. + NameLen uint32 + _ uint32 + + // Iov is a pointer to an array of io vectors that describe the memory + // locations involved in the io operation. + Iov uint64 + + // IovLen is the length of the array pointed to by Iov. + IovLen uint64 + + // Control is the optional pointer to ancillary control data. + Control uint64 + + // ControlLen is the length of the data pointed to by Control. + ControlLen uint64 + + // Flags on the sent/received message. + Flags int32 + _ int32 +} + +// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in +// the recvmmsg and sendmmsg syscalls. +type multipleMessageHeader64 struct { + msgHdr MessageHeader64 + msgLen uint32 + _ int32 +} + +// CopyInMessageHeader64 copies a message header from user to kernel memory. +func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error { + b := t.CopyScratchBuffer(52) + if _, err := t.CopyInBytes(addr, b); err != nil { + return err + } + + msg.Name = usermem.ByteOrder.Uint64(b[0:]) + msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) + msg.Iov = usermem.ByteOrder.Uint64(b[16:]) + msg.IovLen = usermem.ByteOrder.Uint64(b[24:]) + msg.Control = usermem.ByteOrder.Uint64(b[32:]) + msg.ControlLen = usermem.ByteOrder.Uint64(b[40:]) + msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:])) + + return nil +} + +// CaptureAddress allocates memory for and copies a socket address structure +// from the untrusted address space range. +func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { + if addrlen > maxAddrLen { + return nil, syserror.EINVAL + } + + addrBuf := make([]byte, addrlen) + if _, err := t.CopyInBytes(addr, addrBuf); err != nil { + return nil, err + } + + return addrBuf, nil +} + +// writeAddress writes a sockaddr structure and its length to an output buffer +// in the unstrusted address space range. If the address is bigger than the +// buffer, it is truncated. +func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { + // Get the buffer length. + var bufLen uint32 + if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil { + return err + } + + if int32(bufLen) < 0 { + return syserror.EINVAL + } + + // Write the length unconditionally. + if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil { + return err + } + + if addr == nil { + return nil + } + + if bufLen > addrLen { + bufLen = addrLen + } + + // Copy as much of the address as will fit in the buffer. + encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr) + if bufLen > uint32(len(encodedAddr)) { + bufLen = uint32(len(encodedAddr)) + } + _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)]) + return err +} + +// Socket implements the linux syscall socket(2). +func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + domain := int(args[0].Int()) + stype := args[1].Int() + protocol := int(args[2].Int()) + + // Check and initialize the flags. + if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { + return 0, nil, syserror.EINVAL + } + + // Create the new socket. + s, e := socket.NewVFS2(t, domain, linux.SockType(stype&0xf), protocol) + if e != nil { + return 0, nil, e.ToError() + } + defer s.DecRef() + + if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil { + return 0, nil, err + } + + fd, err := t.NewFDFromVFS2(0, s, kernel.FDFlags{ + CloseOnExec: stype&linux.SOCK_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + + return uintptr(fd), nil, nil +} + +// SocketPair implements the linux syscall socketpair(2). +func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + domain := int(args[0].Int()) + stype := args[1].Int() + protocol := int(args[2].Int()) + addr := args[3].Pointer() + + // Check and initialize the flags. + if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { + return 0, nil, syserror.EINVAL + } + + // Create the socket pair. + s1, s2, e := socket.PairVFS2(t, domain, linux.SockType(stype&0xf), protocol) + if e != nil { + return 0, nil, e.ToError() + } + // Adding to the FD table will cause an extra reference to be acquired. + defer s1.DecRef() + defer s2.DecRef() + + nonblocking := uint32(stype & linux.SOCK_NONBLOCK) + if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil { + return 0, nil, err + } + if err := s2.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil { + return 0, nil, err + } + + // Create the FDs for the sockets. + flags := kernel.FDFlags{ + CloseOnExec: stype&linux.SOCK_CLOEXEC != 0, + } + fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{s1, s2}, flags) + if err != nil { + return 0, nil, err + } + + if _, err := t.CopyOut(addr, fds); err != nil { + for _, fd := range fds { + _, file := t.FDTable().Remove(fd) + file.DecRef() + } + return 0, nil, err + } + + return 0, nil, nil +} + +// Connect implements the linux syscall connect(2). +func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Uint() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Capture address and call syscall implementation. + a, err := CaptureAddress(t, addr, addrlen) + if err != nil { + return 0, nil, err + } + + blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 + return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS) +} + +// accept is the implementation of the accept syscall. It is called by accept +// and accept4 syscall handlers. +func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) { + // Check that no unsupported flags are passed in. + if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { + return 0, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, syserror.ENOTSOCK + } + + // Call the syscall implementation for this socket, then copy the + // output address if one is specified. + blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 + + peerRequested := addrLen != 0 + nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) + if e != nil { + return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + } + if peerRequested { + // NOTE(magi): Linux does not give you an error if it can't + // write the data back out so neither do we. + if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL { + return 0, err + } + } + return uintptr(nfd), nil +} + +// Accept4 implements the linux syscall accept4(2). +func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + flags := int(args[3].Int()) + + n, err := accept(t, fd, addr, addrlen, flags) + return n, nil, err +} + +// Accept implements the linux syscall accept(2). +func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + + n, err := accept(t, fd, addr, addrlen, 0) + return n, nil, err +} + +// Bind implements the linux syscall bind(2). +func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Uint() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Capture address and call syscall implementation. + a, err := CaptureAddress(t, addr, addrlen) + if err != nil { + return 0, nil, err + } + + return 0, nil, s.Bind(t, a).ToError() +} + +// Listen implements the linux syscall listen(2). +func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + backlog := args[1].Int() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Per Linux, the backlog is silently capped to reasonable values. + if backlog <= 0 { + backlog = minListenBacklog + } + if backlog > maxListenBacklog { + backlog = maxListenBacklog + } + + return 0, nil, s.Listen(t, int(backlog)).ToError() +} + +// Shutdown implements the linux syscall shutdown(2). +func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + how := args[1].Int() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Validate how, then call syscall implementation. + switch how { + case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR: + default: + return 0, nil, syserror.EINVAL + } + + return 0, nil, s.Shutdown(t, int(how)).ToError() +} + +// GetSockOpt implements the linux syscall getsockopt(2). +func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + level := args[1].Int() + name := args[2].Int() + optValAddr := args[3].Pointer() + optLenAddr := args[4].Pointer() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Read the length. Reject negative values. + optLen := int32(0) + if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + return 0, nil, err + } + if optLen < 0 { + return 0, nil, syserror.EINVAL + } + + // Call syscall implementation then copy both value and value len out. + v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen)) + if e != nil { + return 0, nil, e.ToError() + } + + vLen := int32(binary.Size(v)) + if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + return 0, nil, err + } + + if v != nil { + if _, err := t.CopyOut(optValAddr, v); err != nil { + return 0, nil, err + } + } + + return 0, nil, nil +} + +// getSockOpt tries to handle common socket options, or dispatches to a specific +// socket implementation. +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { + if level == linux.SOL_SOCKET { + switch name { + case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: + if len < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + } + + switch name { + case linux.SO_TYPE: + _, skType, _ := s.Type() + return int32(skType), nil + case linux.SO_DOMAIN: + family, _, _ := s.Type() + return int32(family), nil + case linux.SO_PROTOCOL: + _, _, protocol := s.Type() + return int32(protocol), nil + } + } + + return s.GetSockOpt(t, level, name, optValAddr, len) +} + +// SetSockOpt implements the linux syscall setsockopt(2). +// +// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket. +func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + level := args[1].Int() + name := args[2].Int() + optValAddr := args[3].Pointer() + optLen := args[4].Int() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + if optLen < 0 { + return 0, nil, syserror.EINVAL + } + if optLen > maxOptLen { + return 0, nil, syserror.EINVAL + } + buf := t.CopyScratchBuffer(int(optLen)) + if _, err := t.CopyIn(optValAddr, &buf); err != nil { + return 0, nil, err + } + + // Call syscall implementation. + if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil { + return 0, nil, err.ToError() + } + + return 0, nil, nil +} + +// GetSockName implements the linux syscall getsockname(2). +func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Get the socket name and copy it to the caller. + v, vl, err := s.GetSockName(t) + if err != nil { + return 0, nil, err.ToError() + } + + return 0, nil, writeAddress(t, v, vl, addr, addrlen) +} + +// GetPeerName implements the linux syscall getpeername(2). +func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + addrlen := args[2].Pointer() + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Get the socket peer name and copy it to the caller. + v, vl, err := s.GetPeerName(t) + if err != nil { + return 0, nil, err.ToError() + } + + return 0, nil, writeAddress(t, v, vl, addr, addrlen) +} + +// RecvMsg implements the linux syscall recvmsg(2). +func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + flags := args[2].Int() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Reject flags that we don't handle yet. + if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { + return 0, nil, syserror.EINVAL + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.RecvTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline) + return n, nil, err +} + +// RecvMMsg implements the linux syscall recvmmsg(2). +func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + vlen := args[2].Uint() + flags := args[3].Int() + toPtr := args[4].Pointer() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Reject flags that we don't handle yet. + if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + var haveDeadline bool + var deadline ktime.Time + if toPtr != 0 { + var ts linux.Timespec + if _, err := ts.CopyIn(t, toPtr); err != nil { + return 0, nil, err + } + if !ts.Valid() { + return 0, nil, syserror.EINVAL + } + deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration()) + haveDeadline = true + } + + if !haveDeadline { + if dl := s.RecvTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + } + + var count uint32 + var err error + for i := uint64(0); i < uint64(vlen); i++ { + mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + var n uintptr + if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil { + break + } + + // Copy the received length to the caller. + lp, ok := mp.AddLength(messageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + if _, err = t.CopyOut(lp, uint32(n)); err != nil { + break + } + count++ + } + + if count == 0 { + return 0, nil, err + } + return uintptr(count), nil, nil +} + +func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { + // Capture the message header and io vectors. + var msg MessageHeader64 + if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + return 0, err + } + + if msg.IovLen > linux.UIO_MAXIOV { + return 0, syserror.EMSGSIZE + } + dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + // FIXME(b/63594852): Pretend we have an empty error queue. + if flags&linux.MSG_ERRQUEUE != 0 { + return 0, syserror.EAGAIN + } + + // Fast path when no control message nor name buffers are provided. + if msg.ControlLen == 0 && msg.NameLen == 0 { + n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) + if err != nil { + return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) + } + if !cms.Unix.Empty() { + mflags |= linux.MSG_CTRUNC + cms.Release() + } + + if int(msg.Flags) != mflags { + // Copy out the flags to the caller. + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + return 0, err + } + } + + return uintptr(n), nil + } + + if msg.ControlLen > maxControlLen { + return 0, syserror.ENOBUFS + } + n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) + if e != nil { + return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + } + defer cms.Release() + + controlData := make([]byte, 0, msg.ControlLen) + controlData = control.PackControlMessages(t, cms, controlData) + + if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { + creds, _ := cms.Unix.Credentials.(control.SCMCredentials) + controlData, mflags = control.PackCredentials(t, creds, controlData, mflags) + } + + if cms.Unix.Rights != nil { + controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags) + } + + // Copy the address to the caller. + if msg.NameLen != 0 { + if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil { + return 0, err + } + } + + // Copy the control data to the caller. + if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + return 0, err + } + if len(controlData) > 0 { + if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + return 0, err + } + } + + // Copy out the flags to the caller. + if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + return 0, err + } + + return uintptr(n), nil +} + +// recvFrom is the implementation of the recvfrom syscall. It is called by +// recvfrom and recv syscall handlers. +func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) { + if int(bufLen) < 0 { + return 0, syserror.EINVAL + } + + // Reject flags that we don't handle yet. + if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 { + return 0, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, syserror.ENOTSOCK + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.RecvTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) + cm.Release() + if e != nil { + return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + } + + // Copy the address to the caller. + if nameLenPtr != 0 { + if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil { + return 0, err + } + } + + return uintptr(n), nil +} + +// RecvFrom implements the linux syscall recvfrom(2). +func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + bufPtr := args[1].Pointer() + bufLen := args[2].Uint64() + flags := args[3].Int() + namePtr := args[4].Pointer() + nameLenPtr := args[5].Pointer() + + n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr) + return n, nil, err +} + +// SendMsg implements the linux syscall sendmsg(2). +func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + flags := args[2].Int() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Reject flags that we don't handle yet. + if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { + return 0, nil, syserror.EINVAL + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + n, err := sendSingleMsg(t, s, file, msgPtr, flags) + return n, nil, err +} + +// SendMMsg implements the linux syscall sendmmsg(2). +func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + msgPtr := args[1].Pointer() + vlen := args[2].Uint() + flags := args[3].Int() + + if t.Arch().Width() != 8 { + // We only handle 64-bit for now. + return 0, nil, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, nil, syserror.ENOTSOCK + } + + // Reject flags that we don't handle yet. + if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { + return 0, nil, syserror.EINVAL + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + var count uint32 + var err error + for i := uint64(0); i < uint64(vlen); i++ { + mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + var n uintptr + if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil { + break + } + + // Copy the received length to the caller. + lp, ok := mp.AddLength(messageHeader64Len) + if !ok { + return 0, nil, syserror.EFAULT + } + if _, err = t.CopyOut(lp, uint32(n)); err != nil { + break + } + count++ + } + + if count == 0 { + return 0, nil, err + } + return uintptr(count), nil, nil +} + +func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) { + // Capture the message header. + var msg MessageHeader64 + if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + return 0, err + } + + var controlData []byte + if msg.ControlLen > 0 { + // Put an upper bound to prevent large allocations. + if msg.ControlLen > maxControlLen { + return 0, syserror.ENOBUFS + } + controlData = make([]byte, msg.ControlLen) + if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + return 0, err + } + } + + // Read the destination address if one is specified. + var to []byte + if msg.NameLen != 0 { + var err error + to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen) + if err != nil { + return 0, err + } + } + + // Read data then call the sendmsg implementation. + if msg.IovLen > linux.UIO_MAXIOV { + return 0, syserror.EMSGSIZE + } + src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + controlMessages, err := control.Parse(t, s, controlData) + if err != nil { + return 0, err + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.SendTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + // Call the syscall implementation. + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) + err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) + if err != nil { + controlMessages.Release() + } + return uintptr(n), err +} + +// sendTo is the implementation of the sendto syscall. It is called by sendto +// and send syscall handlers. +func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) { + bl := int(bufLen) + if bl < 0 { + return 0, syserror.EINVAL + } + + // Get socket from the file descriptor. + file := t.GetFileVFS2(fd) + if file == nil { + return 0, syserror.EBADF + } + defer file.DecRef() + + // Extract the socket. + s, ok := file.Impl().(socket.SocketVFS2) + if !ok { + return 0, syserror.ENOTSOCK + } + + if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { + flags |= linux.MSG_DONTWAIT + } + + // Read the destination address if one is specified. + var to []byte + var err error + if namePtr != 0 { + to, err = CaptureAddress(t, namePtr, nameLen) + if err != nil { + return 0, err + } + } + + src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{ + AddressSpaceActive: true, + }) + if err != nil { + return 0, err + } + + var haveDeadline bool + var deadline ktime.Time + if dl := s.SendTimeout(); dl > 0 { + deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond) + haveDeadline = true + } else if dl < 0 { + flags |= linux.MSG_DONTWAIT + } + + // Call the syscall implementation. + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) + return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file) +} + +// SendTo implements the linux syscall sendto(2). +func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + bufPtr := args[1].Pointer() + bufLen := args[2].Uint64() + flags := args[3].Int() + namePtr := args[4].Pointer() + nameLen := args[5].Uint() + + n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen) + return n, nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go b/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go new file mode 100644 index 000000000..7938a5249 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/sys_timerfd.go @@ -0,0 +1,123 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// TimerfdCreate implements Linux syscall timerfd_create(2). +func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + clockID := args[0].Int() + flags := args[1].Int() + + if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 { + return 0, nil, syserror.EINVAL + } + + var fileFlags uint32 + if flags&linux.TFD_NONBLOCK != 0 { + fileFlags = linux.O_NONBLOCK + } + + var clock ktime.Clock + switch clockID { + case linux.CLOCK_REALTIME: + clock = t.Kernel().RealtimeClock() + case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME: + clock = t.Kernel().MonotonicClock() + default: + return 0, nil, syserror.EINVAL + } + file, err := t.Kernel().VFS().NewTimerFD(clock, fileFlags) + if err != nil { + return 0, nil, err + } + defer file.DecRef() + fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.TFD_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil +} + +// TimerfdSettime implements Linux syscall timerfd_settime(2). +func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + flags := args[1].Int() + newValAddr := args[2].Pointer() + oldValAddr := args[3].Pointer() + + if flags&^(linux.TFD_TIMER_ABSTIME) != 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + tfd, ok := file.Impl().(*vfs.TimerFileDescription) + if !ok { + return 0, nil, syserror.EINVAL + } + + var newVal linux.Itimerspec + if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + return 0, nil, err + } + newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock()) + if err != nil { + return 0, nil, err + } + tm, oldS := tfd.SetTime(newS) + if oldValAddr != 0 { + oldVal := ktime.ItimerspecFromSetting(tm, oldS) + if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { + return 0, nil, err + } + } + return 0, nil, nil +} + +// TimerfdGettime implements Linux syscall timerfd_gettime(2). +func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + curValAddr := args[1].Pointer() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + tfd, ok := file.Impl().(*vfs.TimerFileDescription) + if !ok { + return 0, nil, syserror.EINVAL + } + + tm, s := tfd.GetTime() + curVal := ktime.ItimerspecFromSetting(tm, s) + _, err := t.CopyOut(curValAddr, &curVal) + return 0, nil, err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index 89e9ff4d7..af455d5c1 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -51,7 +51,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml } defer tpop.Release() - names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop) + names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) if err != nil { return 0, nil, err } @@ -74,7 +74,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } defer file.DecRef() - names, err := file.Listxattr(t) + names, err := file.Listxattr(t, uint64(size)) if err != nil { return 0, nil, err } @@ -116,7 +116,10 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli return 0, nil, err } - value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name) + value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{ + Name: name, + Size: uint64(size), + }) if err != nil { return 0, nil, err } @@ -145,7 +148,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - value, err := file.Getxattr(t, name) + value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)}) if err != nil { return 0, nil, err } @@ -230,7 +233,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{ + return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{ Name: name, Value: value, Flags: uint32(flags), diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index bf4d27c7d..9aeb83fb0 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -36,6 +36,7 @@ go_library( "pathname.go", "permissions.go", "resolving_path.go", + "timerfd.go", "vfs.go", ], visibility = ["//pkg/sentry:internal"], @@ -51,6 +52,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index d1f6dfb45..a64d86122 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -245,7 +245,7 @@ func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath } // ListxattrAt implements FilesystemImpl.ListxattrAt. -func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) { +func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { if !rp.Done() { return nil, syserror.ENOTDIR } @@ -253,7 +253,7 @@ func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([ } // GetxattrAt implements FilesystemImpl.GetxattrAt. -func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) { +func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) { if !rp.Done() { return "", syserror.ENOTDIR } diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 3da45d744..8e0b40841 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -99,6 +99,8 @@ func (vfs *VirtualFilesystem) NewEpollInstanceFD() (*FileDescription, error) { interest: make(map[epollInterestKey]*epollInterest), } if err := ep.vfsfd.Init(ep, linux.O_RDWR, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, UseDentryMetadata: true, }); err != nil { return nil, err diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 8ee549dc2..4fb9aea87 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -84,6 +84,17 @@ type FileDescriptionOptions struct { // usually only the case if O_DIRECT would actually have an effect. AllowDirectIO bool + // If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE. + DenyPRead bool + + // If DenyPWrite is true, calls to FileDescription.PWrite() return + // ESPIPE. + DenyPWrite bool + + // if InvalidWrite is true, calls to FileDescription.Write() return + // EINVAL. + InvalidWrite bool + // If UseDentryMetadata is true, calls to FileDescription methods that // interact with file and filesystem metadata (Stat, SetStat, StatFS, // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling @@ -175,6 +186,12 @@ func (fd *FileDescription) DecRef() { } } +// Refs returns the current number of references. The returned count +// is inherently racy and is unsafe to use without external synchronization. +func (fd *FileDescription) Refs() int64 { + return atomic.LoadInt64(&fd.refs) +} + // Mount returns the mount on which fd was opened. It does not take a reference // on the returned Mount. func (fd *FileDescription) Mount() *Mount { @@ -306,6 +323,7 @@ type FileDescriptionImpl interface { // - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP. // // Preconditions: The FileDescription was opened for reading. + // FileDescriptionOptions.DenyPRead == false. PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) // Read is similar to PRead, but does not specify an offset. @@ -337,6 +355,7 @@ type FileDescriptionImpl interface { // EOPNOTSUPP. // // Preconditions: The FileDescription was opened for writing. + // FileDescriptionOptions.DenyPWrite == false. PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) // Write is similar to PWrite, but does not specify an offset, which is @@ -382,11 +401,11 @@ type FileDescriptionImpl interface { Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) // Listxattr returns all extended attribute names for the file. - Listxattr(ctx context.Context) ([]string, error) + Listxattr(ctx context.Context, size uint64) ([]string, error) // Getxattr returns the value associated with the given extended attribute // for the file. - Getxattr(ctx context.Context, name string) (string, error) + Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) // Setxattr changes the value associated with the given extended attribute // for the file. @@ -515,6 +534,9 @@ func (fd *FileDescription) EventUnregister(e *waiter.Entry) { // offset, and returns the number of bytes read. PRead is permitted to return // partial reads with a nil error. func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + if fd.opts.DenyPRead { + return 0, syserror.ESPIPE + } if !fd.readable { return 0, syserror.EBADF } @@ -533,6 +555,9 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt // offset, and returns the number of bytes written. PWrite is permitted to // return partial writes with a nil error. func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + if fd.opts.DenyPWrite { + return 0, syserror.ESPIPE + } if !fd.writable { return 0, syserror.EBADF } @@ -541,6 +566,9 @@ func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, o // Write is similar to PWrite, but does not specify an offset. func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + if fd.opts.InvalidWrite { + return 0, syserror.EINVAL + } if !fd.writable { return 0, syserror.EBADF } @@ -577,18 +605,23 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. // Listxattr returns all extended attribute names for the file represented by // fd. -func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { +// +// If the size of the list (including a NUL terminating byte after every entry) +// would exceed size, ERANGE may be returned. Note that implementations +// are free to ignore size entirely and return without error). In all cases, +// if size is 0, the list should be returned without error, regardless of size. +func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp) + names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size) vfsObj.putResolvingPath(rp) return names, err } - names, err := fd.impl.Listxattr(ctx) + names, err := fd.impl.Listxattr(ctx, size) if err == syserror.ENOTSUP { // Linux doesn't actually return ENOTSUP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security @@ -601,34 +634,39 @@ func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { // Getxattr returns the value associated with the given extended attribute for // the file represented by fd. -func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) { +// +// If the size of the return value exceeds opts.Size, ERANGE may be returned +// (note that implementations are free to ignore opts.Size entirely and return +// without error). In all cases, if opts.Size is 0, the value should be +// returned without error, regardless of size. +func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, name) + val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(rp) return val, err } - return fd.impl.Getxattr(ctx, name) + return fd.impl.Getxattr(ctx, *opts) } // Setxattr changes the value associated with the given extended attribute for // the file represented by fd. -func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error { +func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, opts) + err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(rp) return err } - return fd.impl.Setxattr(ctx, opts) + return fd.impl.Setxattr(ctx, *opts) } // Removexattr removes the given extended attribute from the file represented diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index d45e602ce..f4c111926 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -130,14 +130,14 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg // Listxattr implements FileDescriptionImpl.Listxattr analogously to // inode_operations::listxattr == NULL in Linux. -func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) { +func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) { // This isn't exactly accurate; see FileDescription.Listxattr. return nil, syserror.ENOTSUP } // Getxattr implements FileDescriptionImpl.Getxattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) { +func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) { return "", syserror.ENOTSUP } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index cd34782ff..a537a29d1 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -442,7 +442,13 @@ type FilesystemImpl interface { // - If extended attributes are not supported by the filesystem, // ListxattrAt returns nil. (See FileDescription.Listxattr for an // explanation.) - ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) + // + // - If the size of the list (including a NUL terminating byte after every + // entry) would exceed size, ERANGE may be returned. Note that + // implementations are free to ignore size entirely and return without + // error). In all cases, if size is 0, the list should be returned without + // error, regardless of size. + ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) // GetxattrAt returns the value associated with the given extended // attribute for the file at rp. @@ -451,7 +457,15 @@ type FilesystemImpl interface { // // - If extended attributes are not supported by the filesystem, GetxattrAt // returns ENOTSUP. - GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) + // + // - If an extended attribute named opts.Name does not exist, ENODATA is + // returned. + // + // - If the size of the return value exceeds opts.Size, ERANGE may be + // returned (note that implementations are free to ignore opts.Size entirely + // and return without error). In all cases, if opts.Size is 0, the value + // should be returned without error, regardless of size. + GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) // SetxattrAt changes the value associated with the given extended // attribute for the file at rp. @@ -460,6 +474,10 @@ type FilesystemImpl interface { // // - If extended attributes are not supported by the filesystem, SetxattrAt // returns ENOTSUP. + // + // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists, + // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist, + // ENODATA is returned. SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error // RemovexattrAt removes the given extended attribute from the file at rp. @@ -468,6 +486,8 @@ type FilesystemImpl interface { // // - If extended attributes are not supported by the filesystem, // RemovexattrAt returns ENOTSUP. + // + // - If name does not exist, ENODATA is returned. RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error // BoundEndpointAt returns the Unix socket endpoint bound at the path rp. @@ -497,7 +517,7 @@ type FilesystemImpl interface { // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error - // TODO: inotify_add_watch() + // TODO(gvisor.dev/issue/1479): inotify_add_watch() } // PrependPathAtVFSRootError is returned by implementations of diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD new file mode 100644 index 000000000..d8c4d27b9 --- /dev/null +++ b/pkg/sentry/vfs/memxattr/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "memxattr", + srcs = ["xattr.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go new file mode 100644 index 000000000..cc1e7d764 --- /dev/null +++ b/pkg/sentry/vfs/memxattr/xattr.go @@ -0,0 +1,102 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memxattr provides a default, in-memory extended attribute +// implementation. +package memxattr + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" +) + +// SimpleExtendedAttributes implements extended attributes using a map of +// names to values. +// +// +stateify savable +type SimpleExtendedAttributes struct { + // mu protects the below fields. + mu sync.RWMutex `state:"nosave"` + xattrs map[string]string +} + +// Getxattr returns the value at 'name'. +func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) { + x.mu.RLock() + value, ok := x.xattrs[opts.Name] + x.mu.RUnlock() + if !ok { + return "", syserror.ENODATA + } + // Check that the size of the buffer provided in getxattr(2) is large enough + // to contain the value. + if opts.Size != 0 && uint64(len(value)) > opts.Size { + return "", syserror.ERANGE + } + return value, nil +} + +// Setxattr sets 'value' at 'name'. +func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { + x.mu.Lock() + defer x.mu.Unlock() + if x.xattrs == nil { + if opts.Flags&linux.XATTR_REPLACE != 0 { + return syserror.ENODATA + } + x.xattrs = make(map[string]string) + } + + _, ok := x.xattrs[opts.Name] + if ok && opts.Flags&linux.XATTR_CREATE != 0 { + return syserror.EEXIST + } + if !ok && opts.Flags&linux.XATTR_REPLACE != 0 { + return syserror.ENODATA + } + + x.xattrs[opts.Name] = opts.Value + return nil +} + +// Listxattr returns all names in xattrs. +func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { + // Keep track of the size of the buffer needed in listxattr(2) for the list. + listSize := 0 + x.mu.RLock() + names := make([]string, 0, len(x.xattrs)) + for n := range x.xattrs { + names = append(names, n) + // Add one byte per null terminator. + listSize += len(n) + 1 + } + x.mu.RUnlock() + if size != 0 && uint64(listSize) > size { + return nil, syserror.ERANGE + } + return names, nil +} + +// Removexattr removes the xattr at 'name'. +func (x *SimpleExtendedAttributes) Removexattr(name string) error { + x.mu.Lock() + defer x.mu.Unlock() + if _, ok := x.xattrs[name]; !ok { + return syserror.ENODATA + } + delete(x.xattrs, name) + return nil +} diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 7792eb1a0..f06946103 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -233,9 +233,9 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia } vd.dentry.mu.Lock() } - // TODO: Linux requires that either both the mount point and the mount root - // are directories, or neither are, and returns ENOTDIR if this is not the - // case. + // TODO(gvisor.dev/issue/1035): Linux requires that either both the mount + // point and the mount root are directories, or neither are, and returns + // ENOTDIR if this is not the case. mntns := vd.mount.ns mnt := newMount(vfs, fs, root, mntns, opts) vfs.mounts.seq.BeginWrite() @@ -274,9 +274,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti } } - // TODO(jamieliu): Linux special-cases umount of the caller's root, which - // we don't implement yet (we'll just fail it since the caller holds a - // reference on it). + // TODO(gvisor.dev/issue/1035): Linux special-cases umount of the caller's + // root, which we don't implement yet (we'll just fail it since the caller + // holds a reference on it). vfs.mounts.seq.BeginWrite() if opts.Flags&linux.MNT_DETACH == 0 { @@ -835,7 +835,8 @@ func superBlockOpts(mountPath string, mnt *Mount) string { // NOTE(b/147673608): If the mount is a cgroup, we also need to include // the cgroup name in the options. For now we just read that from the // path. - // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we + // + // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we // should get this value from the cgroup itself, and not rely on the // path. if mnt.fs.FilesystemType().Name() == "cgroup" { diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index 3b933468d..3335e4057 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -55,7 +55,7 @@ func TestMountTableInsertLookup(t *testing.T) { } } -// TODO: concurrent lookup/insertion/removal +// TODO(gvisor.dev/issue/1035): concurrent lookup/insertion/removal. // must be powers of 2 var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8} diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 3e90dc4ed..534528ce6 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -16,6 +16,7 @@ package vfs import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" ) // GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and @@ -44,6 +45,10 @@ type MknodOptions struct { // DevMinor are the major and minor device numbers for the created device. DevMajor uint32 DevMinor uint32 + + // Endpoint is the endpoint to bind to the created file, if a socket file is + // being created for bind(2) on a Unix domain socket. + Endpoint transport.BoundEndpoint } // MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC. @@ -127,6 +132,20 @@ type SetStatOptions struct { Stat linux.Statx } +// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(), +// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and +// FileDescriptionImpl.Getxattr(). +type GetxattrOptions struct { + // Name is the name of the extended attribute to retrieve. + Name string + + // Size is the maximum value size that the caller will tolerate. If the value + // is larger than size, getxattr methods may return ERANGE, but they are also + // free to ignore the hint entirely (i.e. the value returned may be larger + // than size). All size checking is done independently at the syscall layer. + Size uint64 +} + // SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), // FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and // FileDescriptionImpl.Setxattr(). diff --git a/pkg/sentry/vfs/timerfd.go b/pkg/sentry/vfs/timerfd.go new file mode 100644 index 000000000..42b880656 --- /dev/null +++ b/pkg/sentry/vfs/timerfd.go @@ -0,0 +1,142 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TimerFileDescription implements FileDescriptionImpl for timer fds. It also +// implements ktime.TimerListener. +type TimerFileDescription struct { + vfsfd FileDescription + FileDescriptionDefaultImpl + DentryMetadataFileDescriptionImpl + + events waiter.Queue + timer *ktime.Timer + + // val is the number of timer expirations since the last successful + // call to PRead, or SetTime. val must be accessed using atomic memory + // operations. + val uint64 +} + +var _ FileDescriptionImpl = (*TimerFileDescription)(nil) +var _ ktime.TimerListener = (*TimerFileDescription)(nil) + +// NewTimerFD returns a new timer fd. +func (vfs *VirtualFilesystem) NewTimerFD(clock ktime.Clock, flags uint32) (*FileDescription, error) { + vd := vfs.NewAnonVirtualDentry("[timerfd]") + defer vd.DecRef() + tfd := &TimerFileDescription{} + tfd.timer = ktime.NewTimer(clock, tfd) + if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{ + UseDentryMetadata: true, + DenyPRead: true, + DenyPWrite: true, + InvalidWrite: true, + }); err != nil { + return nil, err + } + return &tfd.vfsfd, nil +} + +// Read implements FileDescriptionImpl.Read. +func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + const sizeofUint64 = 8 + if dst.NumBytes() < sizeofUint64 { + return 0, syserror.EINVAL + } + if val := atomic.SwapUint64(&tfd.val, 0); val != 0 { + var buf [sizeofUint64]byte + usermem.ByteOrder.PutUint64(buf[:], val) + if _, err := dst.CopyOut(ctx, buf[:]); err != nil { + // Linux does not undo consuming the number of + // expirations even if writing to userspace fails. + return 0, err + } + return sizeofUint64, nil + } + return 0, syserror.ErrWouldBlock +} + +// Clock returns the timer fd's Clock. +func (tfd *TimerFileDescription) Clock() ktime.Clock { + return tfd.timer.Clock() +} + +// GetTime returns the associated Timer's setting and the time at which it was +// observed. +func (tfd *TimerFileDescription) GetTime() (ktime.Time, ktime.Setting) { + return tfd.timer.Get() +} + +// SetTime atomically changes the associated Timer's setting, resets the number +// of expirations to 0, and returns the previous setting and the time at which +// it was observed. +func (tfd *TimerFileDescription) SetTime(s ktime.Setting) (ktime.Time, ktime.Setting) { + return tfd.timer.SwapAnd(s, func() { atomic.StoreUint64(&tfd.val, 0) }) +} + +// Readiness implements waiter.Waitable.Readiness. +func (tfd *TimerFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { + var ready waiter.EventMask + if atomic.LoadUint64(&tfd.val) != 0 { + ready |= waiter.EventIn + } + return ready +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (tfd *TimerFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + tfd.events.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (tfd *TimerFileDescription) EventUnregister(e *waiter.Entry) { + tfd.events.EventUnregister(e) +} + +// PauseTimer pauses the associated Timer. +func (tfd *TimerFileDescription) PauseTimer() { + tfd.timer.Pause() +} + +// ResumeTimer resumes the associated Timer. +func (tfd *TimerFileDescription) ResumeTimer() { + tfd.timer.Resume() +} + +// Release implements FileDescriptionImpl.Release() +func (tfd *TimerFileDescription) Release() { + tfd.timer.Destroy() +} + +// Notify implements ktime.TimerListener.Notify. +func (tfd *TimerFileDescription) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { + atomic.AddUint64(&tfd.val, exp) + tfd.events.Notify(waiter.EventIn) + return ktime.Setting{}, false +} + +// Destroy implements ktime.TimerListener.Destroy. +func (tfd *TimerFileDescription) Destroy() {} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 720b90d8f..f592913d5 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -680,10 +680,10 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti // ListxattrAt returns all extended attribute names for the file at the given // path. -func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) { +func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { rp := vfs.getResolvingPath(creds, pop) for { - names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp) + names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size) if err == nil { vfs.putResolvingPath(rp) return names, nil @@ -705,10 +705,10 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede // GetxattrAt returns the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) { +func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) { rp := vfs.getResolvingPath(creds, pop) for { - val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name) + val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts) if err == nil { vfs.putResolvingPath(rp) return val, nil diff --git a/pkg/state/state.go b/pkg/state/state.go index dbe507ab4..03ae2dbb0 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -241,10 +241,7 @@ func Register(name string, instance interface{}, fns Fns) { // // This function is used by the stateify tool. func IsZeroValue(val interface{}) bool { - if val == nil { - return true - } - return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface()) + return val == nil || reflect.ValueOf(val).Elem().IsZero() } // step captures one encoding / decoding step. On each step, there is up to one diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 5340cf0d6..0e35d7d17 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -31,13 +31,13 @@ go_library( name = "sync", srcs = [ "aliases.go", - "downgradable_rwmutex_unsafe.go", "memmove_unsafe.go", + "mutex_unsafe.go", "norace_unsafe.go", "race_unsafe.go", + "rwmutex_unsafe.go", "seqcount.go", - "syncutil.go", - "tmutex_unsafe.go", + "sync.go", ], ) @@ -45,9 +45,9 @@ go_test( name = "sync_test", size = "small", srcs = [ - "downgradable_rwmutex_test.go", + "mutex_test.go", + "rwmutex_test.go", "seqcount_test.go", - "tmutex_test.go", ], library = ":sync", ) diff --git a/pkg/sync/tmutex_test.go b/pkg/sync/mutex_test.go index 0838248b4..0838248b4 100644 --- a/pkg/sync/tmutex_test.go +++ b/pkg/sync/mutex_test.go diff --git a/pkg/sync/tmutex_unsafe.go b/pkg/sync/mutex_unsafe.go index 3dd15578b..3dd15578b 100644 --- a/pkg/sync/tmutex_unsafe.go +++ b/pkg/sync/mutex_unsafe.go diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/rwmutex_test.go index ce667e825..ce667e825 100644 --- a/pkg/sync/downgradable_rwmutex_test.go +++ b/pkg/sync/rwmutex_test.go diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go index ea6cdc447..ea6cdc447 100644 --- a/pkg/sync/downgradable_rwmutex_unsafe.go +++ b/pkg/sync/rwmutex_unsafe.go diff --git a/pkg/sync/syncutil.go b/pkg/sync/sync.go index b16cf5333..b16cf5333 100644 --- a/pkg/sync/syncutil.go +++ b/pkg/sync/sync.go diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8d42cd066..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -17,6 +17,7 @@ package buffer import ( "bytes" + "io" ) // View is a slice of a buffer, with convenience methods. @@ -89,6 +90,47 @@ func (vv *VectorisedView) TrimFront(count int) { } } +// Read implements io.Reader. +func (vv *VectorisedView) Read(v View) (copied int, err error) { + count := len(v) + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + copy(v[copied:], vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return copied, nil + } + count -= len(vv.views[0]) + copy(v[copied:], vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + if copied == 0 { + return 0, io.EOF + } + return copied, nil +} + +// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It +// returns the number of bytes copied. +func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) { + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + dstVV.AppendView(vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return + } + count -= len(vv.views[0]) + dstVV.AppendView(vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + return copied +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { @@ -116,12 +158,12 @@ func (vv *VectorisedView) CapLength(length int) { // Clone returns a clone of this VectorisedView. // If the buffer argument is large enough to contain all the Views of this VectorisedView, // the method will avoid allocations and use the buffer to store the Views of the clone. -func (vv VectorisedView) Clone(buffer []View) VectorisedView { +func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } // First returns the first view of the vectorised view. -func (vv VectorisedView) First() View { +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { return nil } @@ -134,11 +176,12 @@ func (vv *VectorisedView) RemoveFirst() { return } vv.size -= len(vv.views[0]) + vv.views[0] = nil vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. -func (vv VectorisedView) Size() int { +func (vv *VectorisedView) Size() int { return vv.size } @@ -146,7 +189,7 @@ func (vv VectorisedView) Size() int { // // If the vectorised view contains a single view, that view will be returned // directly. -func (vv VectorisedView) ToView() View { +func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } @@ -158,7 +201,7 @@ func (vv VectorisedView) ToView() View { } // Views returns the slice containing the all views. -func (vv VectorisedView) Views() []View { +func (vv *VectorisedView) Views() []View { return vv.views } diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index ebc3a17b7..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -233,3 +233,140 @@ func TestToClone(t *testing.T) { }) } } + +func TestVVReadToVV(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + wantBytes string + leftVV VectorisedView + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + wantBytes: "0123456789", + leftVV: vv(20, "01234567890123456789"), + }, + { + comment: "largeVV, multiple views, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + wantBytes: "123345", + leftVV: vv(7, "567", "8910"), + }, + { + comment: "smallVV (multiple views), large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + wantBytes: "123", + leftVV: vv(0, ""), + }, + { + comment: "smallVV (single view), large read", + vv: vv(1, "1"), + bytesToRead: 10, + wantBytes: "1", + leftVV: vv(0, ""), + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + wantBytes: "", + leftVV: vv(0, ""), + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + var readTo VectorisedView + inSize := tc.vv.Size() + copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) + if got, want := copied, len(tc.wantBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc) + } + if got, want := string(readTo.ToView()), tc.wantBytes; got != want { + t.Errorf("unexpected content in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { + t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want) + } + }) + } +} + +func TestVVRead(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + readBytes string + leftBytes string + wantError bool + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + readBytes: "0123456789", + leftBytes: "01234567890123456789", + }, + { + comment: "largeVV, multiple buffers, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + readBytes: "123345", + leftBytes: "5678910", + }, + { + comment: "smallVV, large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + readBytes: "123", + leftBytes: "", + }, + { + comment: "smallVV, large read", + vv: vv(1, "1"), + bytesToRead: 10, + readBytes: "1", + leftBytes: "", + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + readBytes: "", + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + readTo := NewView(tc.bytesToRead) + inSize := tc.vv.Size() + copied, err := tc.vv.Read(readTo) + if !tc.wantError && err != nil { + t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err) + } + readTo = readTo[:copied] + if got, want := copied, len(tc.readBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(readTo), tc.readBytes; got != want { + t.Errorf("unexpected data in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want { + t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want) + } + }) + } +} diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 8dc0f7c0e..307f1b666 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -728,7 +728,7 @@ func ICMPv6Code(want byte) TransportChecker { // message for type of ty, with potentially additional checks specified by // checkers. // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid // NDP message as far as the size of the message (minSize) is concerned. The // values within the message are up to checkers to validate. func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { @@ -760,9 +760,9 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N // Neighbor Solicitation message (as per the raw wire format), with potentially // additional checks specified by checkers. // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the messages concerned. The values within -// the message are up to checkers to validate. +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNS message as far as the size of the message is concerned. The values +// within the message are up to checkers to validate. func NDPNS(checkers ...TransportChecker) NetworkChecker { return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) } @@ -780,7 +780,54 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { ns := header.NDPNeighborSolicit(icmp.NDPPayload()) if got := ns.TargetAddress(); got != want { - t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) + t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) + } + } +} + +// NDPNA creates a checker that checks that the packet contains a valid NDP +// Neighbor Advertisement message (as per the raw wire format), with potentially +// additional checks specified by checkers. +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNA message as far as the size of the message is concerned. The values +// within the message are up to checkers to validate. +func NDPNA(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) +} + +// NDPNATargetAddress creates a checker that checks the Target Address field of +// a header.NDPNeighborAdvert. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNATargetAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + + if got := na.TargetAddress(); got != want { + t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) + } + } +} + +// NDPNASolicitedFlag creates a checker that checks the Solicited field of +// a header.NDPNeighborAdvert. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNASolicitedFlag(want bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + + if got := na.SolicitedFlag(); got != want { + t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) } } } @@ -819,6 +866,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) } + case header.NDPTargetLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } default: t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) } @@ -831,6 +885,21 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption } } +// NDPNAOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Neighbor Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNA message as far as the size is concerned. +func NDPNAOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + ndpOptions(t, na.Options(), opts) + } +} + // NDPNSOptions creates a checker that checks that the packet contains the // provided NDP options within an NDP Neighbor Solicitation message. // @@ -849,7 +918,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). // -// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid // NDPRS as far as the size of the message is concerned. The values within the // message are up to checkers to validate. func NDPRS(checkers ...TransportChecker) NetworkChecker { diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 7a0014ad9..14413f2ce 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -88,7 +88,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { - t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr) + t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr) } }) } diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 1b6c3f328..2c4591409 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -62,6 +62,10 @@ const ( // within an IPv6RoutingExtHdr. ipv6RoutingExtHdrSegmentsLeftIdx = 1 + // IPv6FragmentExtHdrLength is the length of an IPv6 extension header, in + // bytes. + IPv6FragmentExtHdrLength = 8 + // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the // Fragment Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetOffset = 0 @@ -391,17 +395,24 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa } // AsRawHeader returns the remaining payload of i as a raw header and -// completes the iterator. +// optionally consumes the iterator. // -// Calls to Next after calling AsRawHeader on i will indicate that the -// iterator is done. -func (i *IPv6PayloadIterator) AsRawHeader() IPv6RawPayloadHeader { - buf := i.payload +// If consume is true, calls to Next after calling AsRawHeader on i will +// indicate that the iterator is done. +func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { identifier := i.nextHdrIdentifier - // Mark i as done. - *i = IPv6PayloadIterator{ - nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + var buf buffer.VectorisedView + if consume { + // Since we consume the iterator, we return the payload as is. + buf = i.payload + + // Mark i as done. + *i = IPv6PayloadIterator{ + nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + } + } else { + buf = i.payload.Clone(nil) } return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf} @@ -420,7 +431,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { // a fragment extension header as the data following the fragment extension // header may not be complete. if i.forceRaw { - return i.AsRawHeader(), false, nil + return i.AsRawHeader(true /* consume */), false, nil } // Is the header we are parsing a known extension header? @@ -452,10 +463,12 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { fragmentExtHdr := IPv6FragmentExtHdr(data) - // If the packet is a fragmented packet, do not attempt to parse - // anything after the fragment extension header as the data following - // the extension header may not be complete. - if fragmentExtHdr.More() || fragmentExtHdr.FragmentOffset() != 0 { + // If the packet is not the first fragment, do not attempt to parse anything + // after the fragment extension header as the payload following the fragment + // extension header should not contain any headers; the first fragment must + // hold all the headers up to and including any upper layer headers, as per + // RFC 8200 section 4.5. + if fragmentExtHdr.FragmentOffset() != 0 { i.forceRaw = true } @@ -476,7 +489,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { default: // The header we are parsing is not a known extension header. Return the // raw payload. - return i.AsRawHeader(), false, nil + return i.AsRawHeader(true /* consume */), false, nil } } diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go index 133ccc8b6..ab20c5f37 100644 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -673,19 +673,26 @@ func TestIPv6ExtHdrIter(t *testing.T) { payload buffer.VectorisedView expected []IPv6PayloadHeader }{ - // With a non-atomic fragment, the payload after the fragment will not be - // parsed because the payload may not be complete. + // With a non-atomic fragment that is not the first fragment, the payload + // after the fragment will not be parsed because the payload is expected to + // only hold upper layer data. { - name: "hopbyhop - fragment - routing - upper", + name: "hopbyhop - fragment (not first) - routing - upper", firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, payload: makeVectorisedViewFromByteBuffers([]byte{ // Hop By Hop extension header. uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, // Fragment extension header. + // + // More = 1, Fragment Offset = 2117, ID = 2147746305 uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, // Routing extension header. + // + // Even though we have a routing ext header here, it should be + // be interpretted as raw bytes as only the first fragment is expected + // to hold headers. 255, 0, 1, 2, 3, 4, 5, 6, // Upper layer data. @@ -701,6 +708,34 @@ func TestIPv6ExtHdrIter(t *testing.T) { }, }, { + name: "hopbyhop - fragment (first) - routing - upper", + firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Hop By Hop extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4, + + // Fragment extension header. + // + // More = 1, Fragment Offset = 0, ID = 2147746305 + uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2, 3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}}, + IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}), + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6RawPayloadHeader{ + Identifier: 255, + Buf: upperLayerData.ToVectorisedView(), + }, + }, + }, + { name: "fragment - routing - upper (across views)", firstNextHdr: IPv6FragmentExtHdrIdentifier, payload: makeVectorisedViewFromByteBuffers([]byte{ diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 1cb9f5dc8..969f8f1e8 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -115,7 +115,7 @@ func TestNDPNeighborAdvert(t *testing.T) { // Make sure flags got updated in the backing buffer. if got := b[ndpNAFlagsOffset]; got != 64 { - t.Errorf("got flags byte = %d, want = 64") + t.Errorf("got flags byte = %d, want = 64", got) } } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a8d6653ce..b4a0ae53d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,7 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt stack.PacketBuffer + Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO Route stack.Route @@ -257,7 +257,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne route := r.Clone() route.Release() p := PacketInfo{ - Pkt: pkt, + Pkt: &pkt, Proto: protocol, GSO: gso, Route: route, @@ -269,21 +269,15 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() - payloadView := pkts[0].Data.ToView() n := 0 - for _, pkt := range pkts { - off := pkt.DataOffset - size := pkt.DataSize + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ - Pkt: stack.PacketBuffer{ - Header: pkt.Header, - Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), - }, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, @@ -301,7 +295,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: stack.PacketBuffer{Data: vv}, + Pkt: &stack.PacketBuffer{Data: vv}, Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3b3b6909b..b857ce9d0 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -91,7 +91,7 @@ func (p PacketDispatchMode) String() string { case PacketMMap: return "PacketMMap" default: - return fmt.Sprintf("unknown packet dispatch mode %v", p) + return fmt.Sprintf("unknown packet dispatch mode '%d'", p) } } @@ -441,118 +441,106 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - var ethHdrBuf []byte - // hdr + data - iovLen := 2 - if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, - } - - // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ - } +// +// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD +// picked to write the packet out has to be the same for all packets in the +// list. In other words all packets in the batch should belong to the same +// flow. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := pkts.Len() - n := len(pkts) - - views := pkts[0].Data.Views() - /* - * Each boundary in views can add one more iovec. - * - * payload | | | | - * ----------------------------- - * packets | | | | | | | - * ----------------------------- - * iovecs | | | | | | | | | - */ - iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) mmsgHdrs := make([]rawfile.MMsgHdr, n) + i := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + var ethHdrBuf []byte + iovLen := 0 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } - iovecIdx := 0 - viewIdx := 0 - viewOff := 0 - off := 0 - nextOff := 0 - for i := range pkts { - // TODO(b/134618279): Different packets may have different data - // in the future. We should handle this. - if !viewsEqual(pkts[i].Data.Views(), views) { - panic("All packets in pkts should have the same Data.") + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ } - prevIovecIdx := iovecIdx - mmsgHdr := &mmsgHdrs[i] - mmsgHdr.Msg.Iov = &iovec[iovecIdx] - packetSize := pkts[i].DataSize - hdr := &pkts[i].Header - - off = pkts[i].DataOffset - if off != nextOff { - // We stop in a different point last time. - size := packetSize - viewIdx = 0 - viewOff = 0 - for size > 0 { - if size >= len(views[viewIdx]) { - viewIdx++ - viewOff = 0 - size -= len(views[viewIdx]) - } else { - viewOff = size - size = 0 + var vnetHdrBuf []byte + vnetHdr := virtioNetHdr{} + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if gso != nil { + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + if gso.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen + vnetHdr.csumOffset = gso.CsumOffset + } + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { + switch gso.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + } + vnetHdr.gsoSize = gso.MSS } } + vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr) + iovLen++ } - nextOff = off + packetSize + iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovecs[0] + iovecIdx := 0 + if vnetHdrBuf != nil { + v := &iovecs[iovecIdx] + v.Base = &vnetHdrBuf[0] + v.Len = uint64(len(vnetHdrBuf)) + iovecIdx++ + } if ethHdrBuf != nil { - v := &iovec[iovecIdx] + v := &iovecs[iovecIdx] v.Base = ðHdrBuf[0] v.Len = uint64(len(ethHdrBuf)) iovecIdx++ } - - v := &iovec[iovecIdx] + pktSize := uint64(0) + // Encode L3 Header + v := &iovecs[iovecIdx] + hdr := &pkt.Header hdrView := hdr.View() v.Base = &hdrView[0] v.Len = uint64(len(hdrView)) + pktSize += v.Len iovecIdx++ - for packetSize > 0 { - vec := &iovec[iovecIdx] + // Now encode the Transport Payload. + pktViews := pkt.Data.Views() + for i := range pktViews { + vec := &iovecs[iovecIdx] iovecIdx++ - - v := views[viewIdx] - vec.Base = &v[viewOff] - s := len(v) - viewOff - if s <= packetSize { - viewIdx++ - viewOff = 0 - } else { - s = packetSize - viewOff += s - } - vec.Len = uint64(s) - packetSize -= s + vec.Base = &pktViews[i][0] + vec.Len = uint64(len(pktViews[i])) + pktSize += vec.Len } - - mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + i++ } packets := 0 for packets < n { - fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))] + fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))] sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) if err != nil { return packets, err diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 4039753b7..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f5973066d..a5478ce17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,7 +87,7 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6461d0108..0796d717e 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0a6b8945c..062388f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, pkt.Header.View(), gso) } @@ -233,20 +233,16 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket(gso, protocol, pkt) + e.dumpPacket(gso, protocol, &pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - view := pkts[0].Data.ToView() - for _, pkt := range pkts { - e.dumpPacket(gso, protocol, stack.PacketBuffer{ - Header: pkt.Header, - Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), - }) +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.dumpPacket(gso, protocol, pkt) } return e.lower.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 52fe397bf..2b3741276 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -112,9 +112,9 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { - return len(pkts), nil + return pkts.Len(), nil } n, err := e.lower.WritePackets(r, gso, pkts, protocol) diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 88224e494..54eb5322b 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -71,9 +71,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcp } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(pkts) - return len(pkts), nil +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += pkts.Len() + return pkts.Len(), nil } func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 255098372..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index b3e239ac7..1646d9cde 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -138,7 +138,8 @@ func TestDirectRequest(t *testing.T) { // Sleep tests are gross, but this will only potentially flake // if there's a bug. If there is no bug this will reliably // succeed. - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4950d69fc..4c20301c6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a7d9a8b25..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -280,28 +280,47 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil + } + + for pkt := pkts.Front(); pkt != nil; { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + pkt = pkt.Next() } // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - for i := range pkts { - if ok := ipt.Check(stack.Output, pkts[i]); !ok { - // iptables is telling us to drop the packet. + dropped := ipt.CheckPackets(stack.Output, pkts) + if len(dropped) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { continue } - ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) - pkts[i].NetworkHeader = buffer.View(ip) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a93a7621a..3f71fc520 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -31,6 +31,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index e0dd5afd3..dc0369156 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived @@ -79,32 +79,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Only the first view in vv is accounted for by h. To account for the // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := pkt.Data + payload := pkt.Data.Clone(nil) payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return } - // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and - // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not - // set to 0. - switch h.Type() { - case header.ICMPv6NeighborSolicit, - header.ICMPv6NeighborAdvert, - header.ICMPv6RouterSolicit, - header.ICMPv6RouterAdvert, - header.ICMPv6RedirectMsg: - if iph.HopLimit() != header.NDPHopLimit { - received.Invalid.Increment() - return - } - - if h.Code() != 0 { - received.Invalid.Increment() - return - } + isNDPValid := func() bool { + // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and + // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. + // + // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the + // packet includes a fragmentation header. + return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0 } // TODO(b/112892170): Meaningfully handle all ICMP types. @@ -133,7 +123,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -148,53 +138,48 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P targetAddr := ns.TargetAddress() s := r.Stack() - rxNICID := r.NICID() - if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil { - // We will only get an error if rxNICID is unrecognized, - // which should not happen. For now short-circuit this - // packet. + if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { + // We will only get an error if the NIC is unrecognized, which should not + // happen. For now, drop this packet. // // TODO(b/141002840): Handle this better? return } else if isTentative { - // If the target address is tentative and the source - // of the packet is a unicast (specified) address, then - // the source of the packet is attempting to perform - // address resolution on the target. In this case, the - // solicitation is silently ignored, as per RFC 4862 - // section 5.4.3. + // If the target address is tentative and the source of the packet is a + // unicast (specified) address, then the source of the packet is + // attempting to perform address resolution on the target. In this case, + // the solicitation is silently ignored, as per RFC 4862 section 5.4.3. // - // If the target address is tentative and the source of - // the packet is the unspecified address (::), then we - // know another node is also performing DAD for the - // same address (since targetAddr is tentative for us, - // we know we are also performing DAD on it). In this - // case we let the stack know so it can handle such a - // scenario and do nothing further with the NDP NS. - if iph.SourceAddress() == header.IPv6Any { - s.DupTentativeAddrDetected(rxNICID, targetAddr) + // If the target address is tentative and the source of the packet is the + // unspecified address (::), then we know another node is also performing + // DAD for the same address (since the target address is tentative for us, + // we know we are also performing DAD on it). In this case we let the + // stack know so it can handle such a scenario and do nothing further with + // the NS. + if r.RemoteAddress == header.IPv6Any { + s.DupTentativeAddrDetected(e.nicID, targetAddr) } - // Do not handle neighbor solicitations targeted - // to an address that is tentative on the received - // NIC any further. + // Do not handle neighbor solicitations targeted to an address that is + // tentative on the NIC any further. return } - // At this point we know that targetAddr is not tentative on - // rxNICID so the packet is processed as defined in RFC 4861, - // as per RFC 4862 section 5.4.3. + // At this point we know that the target address is not tentative on the NIC + // so the packet is processed as defined in RFC 4861, as per RFC 4862 + // section 5.4.3. + // Is the NS targetting us? if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { - // We don't have a useful answer; the best we can do is ignore the request. return } - // If the NS message has the source link layer option, update the link - // address cache with the link address for the sender of the message. + // If the NS message contains the Source Link-Layer Address option, update + // the link address cache with the value of the option. // // TODO(b/148429853): Properly process the NS message and do Neighbor // Unreachability Detection. + var sourceLinkAddr tcpip.LinkAddress for { opt, done, err := it.Next() if err != nil { @@ -207,22 +192,36 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch opt := opt.(type) { case header.NDPSourceLinkLayerAddressOption: - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress()) + // No RFCs define what to do when an NS message has multiple Source + // Link-Layer Address options. Since no interface can have multiple + // link-layer addresses, we consider such messages invalid. + if len(sourceLinkAddr) != 0 { + received.Invalid.Increment() + return + } + + sourceLinkAddr = opt.EthernetAddress() } } - optsSerializer := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), + unspecifiedSource := r.RemoteAddress == header.IPv6Any + + // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST + // NOT be included when the source IP address is the unspecified address. + // Otherwise, on link layers that have addresses this option MUST be + // included in multicast solicitations and SHOULD be included in unicast + // solicitations. + if len(sourceLinkAddr) == 0 { + if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + received.Invalid.Increment() + return + } + } else if unspecifiedSource { + received.Invalid.Increment() + return + } else { + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - packet.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(packet.NDPPayload()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(targetAddr) - opts := na.Options() - opts.Serialize(optsSerializer) // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the @@ -235,6 +234,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P r := r.Clone() defer r.Release() r.LocalAddress = targetAddr + + // As per RFC 4861 section 7.2.4, if the the source of the solicitation is + // the unspecified address, the node MUST set the Solicited flag to zero and + // multicast the advertisement to the all-nodes address. + solicited := true + if unspecifiedSource { + solicited = false + r.RemoteAddress = header.IPv6AllNodesMulticastAddress + } + + // If the NS has a source link-layer option, use the link address it + // specifies as the remote link address for the response instead of the + // source link address of the packet. + // + // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link + // address cache for the right destination link address instead of manually + // patching the route with the remote link address if one is specified in a + // Source Link-Layer Address option. + if len(sourceLinkAddr) != 0 { + r.RemoteLinkAddress = sourceLinkAddr + } + + optsSerializer := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) + na.SetSolicitedFlag(solicited) + na.SetOverrideFlag(true) + na.SetTargetAddress(targetAddr) + opts := na.Options() + opts.Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) @@ -253,7 +286,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } @@ -355,8 +388,20 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterSolicit: received.RouterSolicit.Increment() + if !isNDPValid() { + received.Invalid.Increment() + return + } case header.ICMPv6RouterAdvert: + received.RouterAdvert.Increment() + + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + received.Invalid.Increment() + return + } + routerAddr := iph.SourceAddress() // @@ -370,16 +415,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - p := h.NDPPayload() - - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if len(p) < header.NDPRAMinimumSize { - // ...No, silently drop the packet. - received.Invalid.Increment() - return - } - ra := header.NDPRouterAdvert(p) opts := ra.Options() @@ -395,8 +430,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // as RFC 4861 section 6.1.2 is concerned. // - received.RouterAdvert.Increment() - // Tell the NIC to handle the RA. stack := r.Stack() rxNICID := r.NICID() @@ -404,6 +437,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() + if !isNDPValid() { + received.Invalid.Increment() + return + } default: received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bae09ed94..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -32,7 +32,8 @@ import ( const ( linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) var ( diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 685239017..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -143,19 +143,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil } - for i := range pkts { - hdr := &pkts[i].Header - size := pkts[i].DataSize - ip := e.addIPHeader(r, hdr, size, params) - pkts[i].NetworkHeader = buffer.View(ip) + for pb := pkts.Front(); pb != nil; pb = pb.Next() { + ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) + pb.NetworkHeader = buffer.View(ip) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -185,6 +183,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { pkt.Data.CapLength(int(h.PayloadLength())) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data) + hasFragmentHeader := false for firstHeader := true; ; firstHeader = false { extHdr, done, err := it.Next() @@ -257,6 +256,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } case header.IPv6FragmentExtHdr: + hasFragmentHeader = true + fragmentOffset := extHdr.FragmentOffset() more := extHdr.More() if !more && fragmentOffset == 0 { @@ -269,7 +270,55 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { continue } - rawPayload := it.AsRawHeader() + // Don't consume the iterator if we have the first fragment because we + // will use it to validate that the first fragment holds the upper layer + // header. + rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */) + + if fragmentOffset == 0 { + // Check that the iterator ends with a raw payload as the first fragment + // should include all headers up to and including any upper layer + // headers, as per RFC 8200 section 4.5; only upper layer data + // (non-headers) should follow the fragment extension header. + var lastHdr header.IPv6PayloadHeader + + for { + it, done, err := it.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + lastHdr = it + } + + // If the last header is a raw header, then the last portion of the IPv6 + // payload is not a known IPv6 extension header. Note, this does not + // mean that the last portion is an upper layer header or not an + // extension header because: + // 1) we do not yet support all extension headers + // 2) we do not validate the upper layer header before reassembling. + // + // This check makes sure that a known IPv6 extension header is not + // present after the Fragment extension header in a non-initial + // fragment. + // + // TODO(#2196): Support IPv6 Authentication and Encapsulated + // Security Payload extension headers. + // TODO(#2333): Validate that the upper layer header is valid. + switch lastHdr.(type) { + case header.IPv6RawPayloadHeader: + default: + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + } + fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. @@ -344,7 +393,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, pkt) + e.handleICMP(r, headerView, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 37f7e53ce..95e5dbf8e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -1014,7 +1014,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: nil, }, { name: "Two fragments with routing header with non-zero segments left across fragments", diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index f924ed9e1..8db51da96 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -173,6 +174,257 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { } } +func TestNeighorSolicitationResponse(t *testing.T) { + const nicID = 1 + nicAddr := lladdr0 + remoteAddr := lladdr1 + nicAddrSNMC := header.SolicitedNodeAddr(nicAddr) + nicLinkAddr := linkAddr0 + remoteLinkAddr0 := linkAddr1 + remoteLinkAddr1 := linkAddr2 + + tests := []struct { + name string + nsOpts header.NDPOptionsSerializer + nsSrcLinkAddr tcpip.LinkAddress + nsSrc tcpip.Address + nsDst tcpip.Address + nsInvalid bool + naDstLinkAddr tcpip.LinkAddress + naSolicited bool + naSrc tcpip.Address + naDst tcpip.Address + }{ + { + name: "Unspecified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Unspecified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: true, + }, + + { + name: "Specified source with 1 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Specified source with 2 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + + { + name: "Specified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 2 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(1, 1280, nicLinkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(nicAddr) + opts := ns.Options() + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) + } + + // If we expected the NS to be invalid, we have nothing else to check. + return + } + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + p, got := e.Read() + if !got { + t.Fatal("expected an NDP NA response") + } + + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } + + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) + }) + } +} + // TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a // valid NDP NA message with the Target Link Layer Address option results in a // new entry in the link address cache for the target of the message. @@ -276,9 +528,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { } } -// TestHopLimitValidation is a test that makes sure that NDP packets are only -// received if their IP header's hop limit is set to 255. -func TestHopLimitValidation(t *testing.T) { +func TestNDPValidation(t *testing.T) { setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { t.Helper() @@ -294,12 +544,19 @@ func TestHopLimitValidation(t *testing.T) { return s, ep, r } - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + nextHdr := uint8(header.ICMPv6ProtocolNumber) + if atomicFragment { + bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength) + bytes[0] = nextHdr + nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + } + payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), + NextHeader: nextHdr, HopLimit: hopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, @@ -364,61 +621,93 @@ func TestHopLimitValidation(t *testing.T) { }, } + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code uint8 + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } + for _, typ := range types { t.Run(typ.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - // Should not have received any ICMPv6 packets with - // type = typ.typ. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Receive the NDP packet with an invalid hop limit - // value. - handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) - - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - // Rx count of NDP packet of type typ.typ should not - // have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) - - // Rx count of NDP packet of type typ.typ should have - // increased. - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep, r := setup(t) + defer r.Release() + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) + pkt.SetCode(test.code) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } + + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 + } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + }) } }) } @@ -592,21 +881,18 @@ func TestRouterAdvertValidation(t *testing.T) { Data: hdr.View().ToVectorisedView(), }) + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + if test.expectedSuccess { if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } - } else { if got := invalid.Value(); got != 1 { t.Fatalf("got invalid = %d, want = 1", got) } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } } }) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 8d80e9cee..5e963a4af 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -15,6 +15,18 @@ go_template_instance( }, ) +go_template_instance( + name = "packet_buffer_list", + out = "packet_buffer_list.go", + package = "stack", + prefix = "PacketBuffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*PacketBuffer", + "Linker": "*PacketBuffer", + }, +) + go_library( name = "stack", srcs = [ @@ -29,7 +41,7 @@ go_library( "ndp.go", "nic.go", "packet_buffer.go", - "packet_buffer_state.go", + "packet_buffer_list.go", "rand.go", "registration.go", "route.go", diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c45c43d21..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } @@ -260,10 +260,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 - for _, pkt := range pkts { - e.WritePacket(r, gso, protocol, pkt) + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.WritePacket(r, gso, protocol, *pkt) n++ } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 37907ae24..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -209,6 +209,23 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { return true } +// CheckPackets runs pkts through the rules for hook and returns a map of packets that +// should not go forward. +// +// NOTE: unlike the Check API the returned map contains packets that should be +// dropped. +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if ok := it.Check(hook, *pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + } + return drop +} + // Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 598468bdd..acb2d4731 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -468,7 +468,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View(), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -1959,7 +1959,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // addr2 is deprecated but if explicitly requested, it should be used. fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Another PI w/ 0 preferred lifetime should not result in a deprecation @@ -1972,7 +1972,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { } expectPrimaryAddr(addr1) if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Refresh lifetimes of addr generated from prefix2. @@ -2084,7 +2084,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make @@ -2097,7 +2097,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh lifetimes for addr of prefix1. @@ -2121,7 +2121,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since it is not deprecated. expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Wait for addr of prefix1 to be invalidated. @@ -2564,7 +2564,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { AddressWithPrefix: addr2, } if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -3483,7 +3483,8 @@ func TestRouterSolicitation(t *testing.T) { e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() p, ok := e.ReadContext(ctx) if !ok { t.Fatal("timed out waiting for packet") @@ -3513,7 +3514,8 @@ func TestRouterSolicitation(t *testing.T) { } waitForNothing := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9367de180..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -23,9 +23,11 @@ import ( // As a PacketBuffer traverses up the stack, it may be necessary to pass it to // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. -// -// +stateify savable type PacketBuffer struct { + // PacketBufferEntry is used to build an intrusive list of + // PacketBuffers. + PacketBufferEntry + // Data holds the payload of the packet. For inbound packets, it also // holds the headers, which are consumed as the packet moves up the // stack. Headers are guaranteed not to be split across views. @@ -34,14 +36,6 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView - // DataOffset is used for GSO output. It is the offset into the Data - // field where the payload of this packet starts. - DataOffset int - - // DataSize is used for GSO output. It is the size of this packet's - // payload. - DataSize int - // Header holds the headers of outbound packets. As a packet is passed // down the stack, each layer adds to Header. Header buffer.Prependable diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ac043b722..23ca9ee03 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -246,7 +246,7 @@ type NetworkEndpoint interface { // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9fbe8a411..a0e5e0300 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -168,23 +168,26 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuff return err } -// WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +// WritePackets writes a list of n packets through the given route and returns +// the number of packets written. +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - payloadSize := 0 - for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) - payloadSize += pkts[i].DataSize + + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Header.UsedLength() + writtenBytes += pb.Data.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b8543b71e..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -153,7 +153,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } @@ -1445,19 +1445,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } // If the NIC doesn't exist, it won't work. if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } } @@ -1483,12 +1483,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { } nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } // Set the initial route table. @@ -1503,10 +1503,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // When an interface is given, the route for a broadcast goes through it. r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } // When an interface is not given, it consults the route table. @@ -2399,7 +2399,7 @@ func TestNICContextPreservation(t *testing.T) { t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) } if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) } }) } @@ -2768,7 +2768,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { - t.Fatalf("NewSubnet failed:", err) + t.Fatalf("NewSubnet failed: %v", err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } @@ -2782,11 +2782,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // permanentExpired kind. r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) if err != nil { - t.Fatal("FindRoute failed:", err) + t.Fatalf("FindRoute failed: %v", err) } defer r.Release() if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } // @@ -2798,7 +2798,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } @@ -2806,7 +2806,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // make sure the new peb was respected. // (The address should just be promoted now). if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[1].ProtocolAddresses { @@ -2839,11 +2839,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // GetMainNICAddress; else, our original address // should be returned. if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } addr, err = s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + t.Fatalf("s.GetMainNICAddress failed: %v", err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index c65b0c632..2474a7db3 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -206,7 +206,7 @@ func TestTransportDemuxerRegister(t *testing.T) { // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { - reuse int + reuse bool bindToDevice tcpip.NICID } for _, test := range []struct { @@ -221,11 +221,11 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -236,9 +236,9 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {reuse: 0, bindToDevice: 1}, - {reuse: 0, bindToDevice: 2}, - {reuse: 0, bindToDevice: 3}, + {reuse: false, bindToDevice: 1}, + {reuse: false, bindToDevice: 2}, + {reuse: false, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -253,12 +253,12 @@ func TestBindToDeviceDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -309,9 +309,8 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err) + if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { + t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2ef3271f1..aec7126ff 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -520,34 +520,90 @@ type WriteOptions struct { type SockOptBool int const ( + // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether + // datagram sockets are allowed to send packets to a broadcast address. + BroadcastOption SockOptBool = iota + + // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be + // held until segments are full by the TCP transport protocol. + CorkOption + + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + + // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether + // TCP keepalive is enabled for this socket. + KeepaliveEnabledOption + + // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether + // multicast packets sent over a non-loopback interface will be looped back. + MulticastLoopOption + + // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether + // SCM_CREDENTIALS socket control messages are enabled. + // + // Only supported on Unix sockets. + PasscredOption + + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. + QuickAckOption + // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the // IPV6_TCLASS ancillary message is passed with incoming packets. - ReceiveTClassOption SockOptBool = iota + ReceiveTClassOption // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS // ancillary message is passed with incoming packets. ReceiveTOSOption - // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 - // socket is to be restricted to sending and receiving IPv6 packets only. - V6OnlyOption - // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify // if more inforamtion is provided with incoming packets such // as interface index and address. ReceiveIPPacketInfoOption - // TODO(b/146901447): convert existing bool socket options to be handled via - // Get/SetSockOptBool + // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() + // should allow reuse of local address. + ReuseAddressOption + + // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets + // to be bound to an identical socket address. + ReusePortOption + + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption ) // SockOptInt represents socket options which values have the int type. type SockOptInt int const ( + // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number + // of un-ACKed TCP keepalives that will be sent before the connection is + // closed. + KeepaliveCountOption SockOptInt = iota + + // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv4 packets from the endpoint. + IPv4TOSOption + + // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv6 packets from the endpoint. + IPv6TrafficClassOption + + // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current + // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. + MaxSegOption + + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default + // TTL value for multicast messages. The default is 1. + MulticastTTLOption + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOptInt = iota + ReceiveQueueSizeOption // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. @@ -561,44 +617,21 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. For TCP, - // it determines if the Nagle algorithm is on or off. - DelayOption - - // TODO(b/137664753): convert all int socket options to be handled via - // GetSockOptInt. + // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop + // limit value for unicast messages. The default is protocol specific. + // + // A zero value indicates the default. + TTLOption ) // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} -// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be -// held until segments are full by the TCP transport protocol. -type CorkOption int - -// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() -// should allow reuse of local address. -type ReuseAddressOption int - -// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets -// to be bound to an identical socket address. -type ReusePortOption int - // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. type BindToDeviceOption NICID -// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. -type QuickAckOption int - -// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether -// SCM_CREDENTIALS socket control messages are enabled. -// -// Only supported on Unix sockets. -type PasscredOption int - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -607,10 +640,6 @@ type TCPInfoOption struct { RTTVar time.Duration } -// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP keepalive is enabled for this socket. -type KeepaliveEnabledOption int - // KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a // connection must remain idle before the first TCP keepalive packet is sent. // Once this time is reached, KeepaliveIntervalOption is used instead. @@ -620,11 +649,6 @@ type KeepaliveIdleOption time.Duration // interval between sending TCP keepalive packets. type KeepaliveIntervalOption time.Duration -// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number -// of un-ACKed TCP keepalives that will be sent before the connection is -// closed. -type KeepaliveCountOption int - // TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user // specified timeout for a given TCP connection. // See: RFC5482 for details. @@ -638,20 +662,9 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive // buffer moderation. type ModerateReceiveBufferOption bool -// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current -// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. -type MaxSegOption int - -// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop -// limit value for unicast messages. The default is protocol specific. -// -// A zero value indicates the default. -type TTLOption uint8 - // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. @@ -668,10 +681,6 @@ type TCPTimeWaitTimeoutOption time.Duration // for a handshake till the specified timeout until a segment with data arrives. type TCPDeferAcceptOption time.Duration -// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default -// TTL value for multicast messages. The default is 1. -type MulticastTTLOption uint8 - // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { @@ -679,10 +688,6 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } -// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether -// multicast packets sent over a non-loopback interface will be looped back. -type MulticastLoopOption bool - // MembershipOption is used by SetSockOpt/GetSockOpt as an argument to // AddMembershipOption and RemoveMembershipOption. type MembershipOption struct { @@ -705,22 +710,10 @@ type RemoveMembershipOption MembershipOption // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int -// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether -// datagram sockets are allowed to send packets to a broadcast address. -type BroadcastOption int - // DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify // a default TTL. type DefaultTTLOption uint8 -// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv4 packets from the endpoint. -type IPv4TOSOption uint8 - -// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv6 packets from the endpoint. -type IPv6TrafficClassOption uint8 - // IPPacketInfo is the message struture for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 8c0aacffa..1c8e2bc34 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -218,7 +218,7 @@ func TestAddressWithPrefixSubnet(t *testing.T) { gotSubnet := ap.Subnet() wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) if err != nil { - t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) continue } if gotSubnet != wantSubnet { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b007302fb..3a133eef9 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -348,29 +348,37 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(o) - e.mu.Unlock() - } - return nil } // SetSockOptBool sets a socket option. Currently not supported. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - return nil + return tcpip.ErrUnknownProtocolOption } // SetSockOptInt sets a socket option. Currently not supported. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + default: + return tcpip.ErrUnknownProtocolOption + } return nil } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -397,26 +405,23 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + case tcpip.TTLOption: + e.rcvMu.Lock() + v := int(e.ttl) + e.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.TTLOption: - e.rcvMu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.rcvMu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 337bc1c71..eee754a5a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -533,14 +533,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -548,7 +544,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -576,9 +578,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + default: + return -1, tcpip.ErrUnknownProtocolOption } - - return -1, tcpip.ErrUnknownProtocolOption } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3239a5911..2ca3fb809 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -756,8 +756,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) hdr := &pkt.Header - packetSize := pkt.DataSize - off := pkt.DataOffset + packetSize := pkt.Data.Size() // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) pkt.TransportHeader = buffer.View(tcp) @@ -782,12 +781,18 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) + xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } } func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + // We need to shallow clone the VectorisedView here as ReadToView will + // split the VectorisedView and Trim underlying views as it splits. Not + // doing the clone here will cause the underlying views of data itself + // to be altered. + data = data.Clone(nil) + optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -796,31 +801,25 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso mss := int(gso.MSS) n := (data.Size() + mss - 1) / mss - // Allocate one big slice for all the headers. - hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen - buf := make([]byte, n*hdrSize) - pkts := make([]stack.PacketBuffer, n) - for i := range pkts { - pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) - } - size := data.Size() - off := 0 + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + var pkts stack.PacketBufferList for i := 0; i < n; i++ { packetSize := mss if packetSize > size { packetSize = size } size -= packetSize - pkts[i].DataOffset = off - pkts[i].DataSize = packetSize - pkts[i].Data = data - pkts[i].Hash = tf.txHash - pkts[i].Owner = owner - buildTCPHdr(r, tf, &pkts[i], gso) - off += packetSize + var pkt stack.PacketBuffer + pkt.Header = buffer.NewPrependable(hdrSize) + pkt.Hash = tf.txHash + pkt.Owner = owner + data.ReadToVV(&pkt.Data, packetSize) + buildTCPHdr(r, tf, &pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkts.PushBack(&pkt) } + if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } @@ -845,12 +844,10 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac } pkt := stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - DataOffset: 0, - DataSize: data.Size(), - Data: data, - Hash: tf.txHash, - Owner: owner, + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Data: data, + Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9b123e968..a8d443f73 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -821,7 +821,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { - e.SetSockOptInt(tcpip.DelayOption, 1) + e.SetSockOptBool(tcpip.DelayOption, true) } var tcpLT tcpip.TCPLingerTimeoutOption @@ -1409,10 +1409,60 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - e.LockUser() - defer e.UnlockUser() - switch opt { + + case tcpip.BroadcastOption: + e.LockUser() + e.broadcast = v + e.UnlockUser() + + case tcpip.CorkOption: + e.LockUser() + if !v { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + e.UnlockUser() + + case tcpip.DelayOption: + if v { + atomic.StoreUint32(&e.delay, 1) + } else { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.QuickAckOption: + o := uint32(1) + if v { + o = 0 + } + atomic.StoreUint32(&e.slowAck, o) + + case tcpip.ReuseAddressOption: + e.LockUser() + e.reuseAddr = v + e.UnlockUser() + return nil + + case tcpip.ReusePortOption: + e.LockUser() + e.reusePort = v + e.UnlockUser() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1424,7 +1474,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + e.LockUser() e.v6only = v + e.UnlockUser() + default: + return tcpip.ErrUnknownProtocolOption } return nil @@ -1432,7 +1486,40 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = int(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.IPv4TOSOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.userMSS = uint16(userMSS) + e.UnlockUser() + e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1483,7 +1570,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() e.notifyProtocolGoroutine(mask) - return nil case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max @@ -1502,52 +1588,21 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.sndBufMu.Lock() e.sndBufSize = size e.sndBufMu.Unlock() - return nil - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil + case tcpip.TTLOption: + e.LockUser() + e.ttl = uint8(v) + e.UnlockUser() default: - return nil + return tcpip.ErrUnknownProtocolOption } + return nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 switch v := opt.(type) { - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.LockUser() - e.reuseAddr = v != 0 - e.UnlockUser() - return nil - - case tcpip.ReusePortOption: - e.LockUser() - e.reusePort = v != 0 - e.UnlockUser() - return nil - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -1556,72 +1611,26 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.LockUser() e.bindToDevice = id e.UnlockUser() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.LockUser() - e.userMSS = uint16(userMSS) - e.UnlockUser() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - - case tcpip.TTLOption: - e.LockUser() - e.ttl = uint8(v) - e.UnlockUser() - return nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v != 0 - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIntervalOption: e.keepalive.Lock() e.keepalive.interval = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil - case tcpip.KeepaliveCountOption: - e.keepalive.Lock() - e.keepalive.count = int(v) - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + case tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. case tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(v) e.UnlockUser() - return nil - - case tcpip.BroadcastOption: - e.LockUser() - e.broadcast = v != 0 - e.UnlockUser() - return nil case tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and @@ -1652,22 +1661,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.IPv4TOSOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - - case tcpip.IPv6TrafficClassOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - case tcpip.TCPLingerTimeoutOption: e.LockUser() if v < 0 { @@ -1688,7 +1681,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.tcpLingerTimeout = time.Duration(v) e.UnlockUser() - return nil case tcpip.TCPDeferAcceptOption: e.LockUser() @@ -1697,11 +1689,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.deferAccept = time.Duration(v) e.UnlockUser() - return nil default: return nil } + return nil } // readyReceiveSize returns the number of bytes ready to be received. @@ -1723,6 +1715,43 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.LockUser() + v := e.broadcast + e.UnlockUser() + return v, nil + + case tcpip.CorkOption: + return atomic.LoadUint32(&e.cork) != 0, nil + + case tcpip.DelayOption: + return atomic.LoadUint32(&e.delay) != 0, nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + return v, nil + + case tcpip.QuickAckOption: + v := atomic.LoadUint32(&e.slowAck) == 0 + return v, nil + + case tcpip.ReuseAddressOption: + e.LockUser() + v := e.reuseAddr + e.UnlockUser() + + return v, nil + + case tcpip.ReusePortOption: + e.LockUser() + v := e.reusePort + e.UnlockUser() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1734,14 +1763,41 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.UnlockUser() return v, nil - } - return false, tcpip.ErrUnknownProtocolOption + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + v := e.keepalive.count + e.keepalive.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + v := header.TCPDefaultMSS + return v, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1757,12 +1813,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil - case tcpip.DelayOption: - var o int - if v := atomic.LoadUint32(&e.delay); v != 0 { - o = 1 - } - return o, nil + case tcpip.TTLOption: + e.LockUser() + v := int(e.ttl) + e.UnlockUser() + return v, nil default: return -1, tcpip.ErrUnknownProtocolOption @@ -1779,61 +1834,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err - case *tcpip.MaxSegOption: - // This is just stubbed out. Linux never returns the user_mss - // value as it either returns the defaultMSS or returns the - // actual current MSS. Netstack just returns the defaultMSS - // always for now. - *o = header.TCPDefaultMSS - return nil - - case *tcpip.CorkOption: - *o = 0 - if v := atomic.LoadUint32(&e.cork); v != 0 { - *o = 1 - } - return nil - - case *tcpip.ReuseAddressOption: - e.LockUser() - v := e.reuseAddr - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.ReusePortOption: - e.LockUser() - v := e.reusePort - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.BindToDeviceOption: e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.UnlockUser() - return nil - - case *tcpip.QuickAckOption: - *o = 1 - if v := atomic.LoadUint32(&e.slowAck); v != 0 { - *o = 0 - } - return nil - - case *tcpip.TTLOption: - e.LockUser() - *o = tcpip.TTLOption(e.ttl) - e.UnlockUser() - return nil case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} @@ -1846,92 +1850,45 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { o.RTTVar = snd.rtt.rttvar snd.rtt.Unlock() } - return nil - - case *tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) e.keepalive.Unlock() - return nil case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) e.keepalive.Unlock() - return nil - - case *tcpip.KeepaliveCountOption: - e.keepalive.Lock() - *o = tcpip.KeepaliveCountOption(e.keepalive.count) - e.keepalive.Unlock() - return nil case *tcpip.TCPUserTimeoutOption: e.LockUser() *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - return nil case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 - return nil - - case *tcpip.BroadcastOption: - e.LockUser() - v := e.broadcast - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc e.UnlockUser() - return nil - - case *tcpip.IPv4TOSOption: - e.LockUser() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.UnlockUser() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.LockUser() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.UnlockUser() - return nil case *tcpip.TCPLingerTimeoutOption: e.LockUser() *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) e.UnlockUser() - return nil case *tcpip.TCPDeferAcceptOption: e.LockUser() *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // checkV4MappedLocked determines the effective network protocol and converts diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index e6fe7985d..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -77,9 +77,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V id: id, route: r.Clone(), } - s.views[0] = v - s.data = buffer.NewVectorisedView(len(v), s.views[:1]) s.rcvdTime = time.Now() + if len(v) != 0 { + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + } return s } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ce3df7478..29301a45c 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -284,7 +284,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err) } c.EP.Close() @@ -728,7 +728,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize tests := []struct { name string - setMSS uint16 + setMSS int expMSS uint16 }{ { @@ -756,15 +756,14 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { c.Create(-1) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -818,15 +817,14 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) { c.CreateV6Endpoint(true) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -1077,17 +1075,17 @@ func TestTOSV4(t *testing.T) { c.EP = ep const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) } - var v tcpip.IPv4TOSOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Errorf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) } testV4Connect(t, c, checker.TOS(tos, 0)) @@ -1125,17 +1123,17 @@ func TestTrafficClassV6(t *testing.T) { c.CreateV6Endpoint(false) const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { + t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) } - var v tcpip.IPv6TrafficClassOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) } // Test the connection request. @@ -1711,7 +1709,7 @@ func TestNoWindowShrinking(t *testing.T) { c.CreateConnected(789, 30000, 10) if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err) } we, ch := waiter.NewChannelEntry(nil) @@ -1984,7 +1982,7 @@ func TestScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2057,7 +2055,7 @@ func TestNonScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2221,10 +2219,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) + ep.SetSockOptBool(tcpip.CorkOption, true) }, func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) + ep.SetSockOptBool(tcpip.CorkOption, false) }, }, } @@ -2316,7 +2314,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -2364,7 +2362,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -2397,7 +2395,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptInt(tcpip.DelayOption, 0) + c.EP.SetSockOptBool(tcpip.DelayOption, false) // Check that data is received. second := c.GetPacket() @@ -2434,8 +2432,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, } for _, test := range tests { @@ -2576,12 +2574,12 @@ func TestSetTTL(t *testing.T) { t.Fatalf("NewEndpoint failed: %v", err) } - if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("Unexpected return value from Connect: %s", err) } // Receive SYN packet. @@ -2621,7 +2619,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2765,7 +2763,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { const rcvBufferSize = 0x20000 const wndScale = 2 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } // Start connection attempt. @@ -3882,26 +3880,26 @@ func TestMinMaxBufferSizes(t *testing.T) { // Set values below the min. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) @@ -4147,11 +4145,11 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { case "ipv4": case "ipv6": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) } case "dual": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) } default: t.Fatalf("unknown network: '%s'", network) @@ -4477,8 +4475,8 @@ func TestKeepalive(t *testing.T) { const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -5609,7 +5607,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { return } if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd) + t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) } }, )) @@ -5770,14 +5768,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { func TestDelayEnabled(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - checkDelayOption(t, c, false, 0) // Delay is disabled by default. + checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { delayEnabled tcp.DelayEnabled - wantDelayOption int + wantDelayOption bool }{ - {delayEnabled: false, wantDelayOption: 0}, - {delayEnabled: true, wantDelayOption: 1}, + {delayEnabled: false, wantDelayOption: false}, + {delayEnabled: true, wantDelayOption: true}, } { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5788,7 +5786,7 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() var gotDelayEnabled tcp.DelayEnabled @@ -5803,12 +5801,12 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del if err != nil { t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) } - gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { - t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) } if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) } } @@ -6620,8 +6618,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // Set userTimeout to be the duration for 3 keepalive probes. userTimeout := 30 * time.Millisecond diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d4f6bc635..431ab4e6b 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -217,7 +217,8 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), wait) + ctx, cancel := context.WithTimeout(context.Background(), wait) + defer cancel() if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) } @@ -235,7 +236,8 @@ func (c *Context) CheckNoPacket(errMsg string) { func (c *Context) GetPacket() []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") @@ -486,7 +488,8 @@ func (c *Context) CreateV6Endpoint(v6only bool) { func (c *Context) GetV6Packet() []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 120d3baa3..492cc1fcb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -501,11 +501,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v + e.mu.Unlock() + + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v e.mu.Unlock() - return nil case tcpip.ReceiveTClassOption: // We only support this option on v6 endpoints. @@ -516,7 +525,18 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Lock() e.receiveTClass = v e.mu.Unlock() - return nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + + case tcpip.ReuseAddressOption: + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v + e.mu.Unlock() case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. @@ -533,13 +553,8 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } e.v6only = v - return nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - return nil + default: + return tcpip.ErrUnknownProtocolOption } return nil @@ -547,22 +562,40 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return nil -} + switch opt { + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) e.mu.Unlock() - case tcpip.MulticastTTLOption: + case tcpip.IPv4TOSOption: e.mu.Lock() - e.multicastTTL = uint8(v) + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) e.mu.Unlock() + case tcpip.ReceiveBufferSizeOption: + case tcpip.SendBufferSizeOption: + + default: + return tcpip.ErrUnknownProtocolOption + } + + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.MulticastInterfaceOption: e.mu.Lock() defer e.mu.Unlock() @@ -686,16 +719,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = bool(v) - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -704,26 +727,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() - return nil - - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v != 0 - e.mu.Unlock() - - return nil - - case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil - - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil } return nil } @@ -731,6 +734,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + return v, nil + + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -748,6 +766,22 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil + + case tcpip.ReuseAddressOption: + return false, nil + + case tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -760,19 +794,32 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil + default: + return false, tcpip.ErrUnknownProtocolOption } - - return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -794,29 +841,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := e.rcvBufSizeMax e.rcvMu.Unlock() return v, nil - } - return -1, tcpip.ErrUnknownProtocolOption + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - return nil - - case *tcpip.TTLOption: - e.mu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() - return nil - - case *tcpip.MulticastTTLOption: - e.mu.Lock() - *o = tcpip.MulticastTTLOption(e.multicastTTL) - e.mu.Unlock() - return nil - case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -824,67 +864,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.multicastAddr, } e.mu.Unlock() - return nil - - case *tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - - *o = tcpip.MulticastLoopOption(v) - return nil - - case *tcpip.ReuseAddressOption: - *o = 0 - return nil - - case *tcpip.ReusePortOption: - e.mu.RLock() - v := e.reusePort - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.BindToDeviceOption: e.mu.RLock() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.IPv4TOSOption: - e.mu.RLock() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0905726c1..8acaa607a 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -343,11 +343,11 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOptBool failed: %s", err) } } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) + if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) } } } @@ -358,7 +358,8 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") @@ -607,7 +608,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr) } // Check the payload. @@ -1271,8 +1272,8 @@ func TestTTL(t *testing.T) { c.createEndpointForFlow(flow) const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { + c.t.Fatalf("SetSockOptInt failed: %s", err) } var wantTTL uint8 @@ -1311,8 +1312,8 @@ func TestSetTTL(t *testing.T) { c.createEndpointForFlow(flow) - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } var p stack.NetworkProtocol @@ -1346,25 +1347,26 @@ func TestSetTOS(t *testing.T) { c.createEndpointForFlow(flow) const tos = testTOS - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) + if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tos { + c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1381,25 +1383,26 @@ func TestSetTClass(t *testing.T) { c.createEndpointForFlow(flow) const tClass = testTOS - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) + if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { + c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tClass); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tClass { + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) } // The header getter for TClass is called TOS, so use that checker. @@ -1430,7 +1433,7 @@ func TestReceiveTosTClass(t *testing.T) { // Verify that setting and reading the option works. v, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } // Test for expected default value. if v != false { @@ -1444,7 +1447,7 @@ func TestReceiveTosTClass(t *testing.T) { got, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } if got != want { @@ -1563,7 +1566,8 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) } @@ -1571,7 +1575,8 @@ func TestV4UnknownDestination(t *testing.T) { } // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { t.Fatalf("packet wasn't written out") @@ -1639,7 +1644,8 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) } @@ -1647,7 +1653,8 @@ func TestV6UnknownDestination(t *testing.T) { } // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { t.Fatalf("packet wasn't written out") diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index d2f4403b0..cd6a0ea6b 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -29,9 +29,6 @@ import ( ) // IO provides access to the contents of a virtual memory space. -// -// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any -// meaningful data. type IO interface { // CopyOut copies len(src) bytes from src to the memory mapped at addr. It // returns the number of bytes copied. If the number of bytes copied is < |