diff options
Diffstat (limited to 'pkg')
131 files changed, 5529 insertions, 2452 deletions
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go index 40f0459a0..7c9a02f20 100644 --- a/pkg/abi/linux/elf.go +++ b/pkg/abi/linux/elf.go @@ -102,4 +102,7 @@ const ( // NT_X86_XSTATE is for x86 extended state using xsave. NT_X86_XSTATE = 0x202 + + // NT_ARM_TLS is for ARM TLS register. + NT_ARM_TLS = 0x401 ) diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index e229ac21c..dbe58acbe 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -266,6 +266,9 @@ type Statx struct { DevMinor uint32 } +// SizeOfStatx is the size of a Statx struct. +var SizeOfStatx = binary.Size(Statx{}) + // FileMode represents a mode_t. type FileMode uint16 diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index bd2e13ba1..80dc09aa9 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -158,10 +158,32 @@ type IPTIP struct { // Flags define matching behavior for the IP header. Flags uint8 - // InverseFlags invert the meaning of fields in struct IPTIP. + // InverseFlags invert the meaning of fields in struct IPTIP. See the + // IPT_INV_* flags. InverseFlags uint8 } +// Flags in IPTIP.InverseFlags. Corresponding constants are in +// include/uapi/linux/netfilter_ipv4/ip_tables.h. +const ( + // Invert the meaning of InputInterface. + IPT_INV_VIA_IN = 0x01 + // Invert the meaning of OutputInterface. + IPT_INV_VIA_OUT = 0x02 + // Unclear what this is, as no references to it exist in the kernel. + IPT_INV_TOS = 0x04 + // Invert the meaning of Src. + IPT_INV_SRCIP = 0x08 + // Invert the meaning of Dst. + IPT_INV_DSTIP = 0x10 + // Invert the meaning of the IPT_F_FRAG flag. + IPT_INV_FRAG = 0x20 + // Invert the meaning of the Protocol field. + IPT_INV_PROTO = 0x40 + // Enable all flags. + IPT_INV_MASK = 0x7F +) + // SizeOfIPTIP is the size of an IPTIP. const SizeOfIPTIP = 84 @@ -253,6 +275,50 @@ type XTErrorTarget struct { // SizeOfXTErrorTarget is the size of an XTErrorTarget. const SizeOfXTErrorTarget = 64 +// Flag values for NfNATIPV4Range. The values indicate whether to map +// protocol specific part(ports) or IPs. It corresponds to values in +// include/uapi/linux/netfilter/nf_nat.h. +const ( + NF_NAT_RANGE_MAP_IPS = 1 << 0 + NF_NAT_RANGE_PROTO_SPECIFIED = 1 << 1 + NF_NAT_RANGE_PROTO_RANDOM = 1 << 2 + NF_NAT_RANGE_PERSISTENT = 1 << 3 + NF_NAT_RANGE_PROTO_RANDOM_FULLY = 1 << 4 + NF_NAT_RANGE_PROTO_RANDOM_ALL = (NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY) + NF_NAT_RANGE_MASK = (NF_NAT_RANGE_MAP_IPS | + NF_NAT_RANGE_PROTO_SPECIFIED | NF_NAT_RANGE_PROTO_RANDOM | + NF_NAT_RANGE_PERSISTENT | NF_NAT_RANGE_PROTO_RANDOM_FULLY) +) + +// NfNATIPV4Range corresponds to struct nf_nat_ipv4_range +// in include/uapi/linux/netfilter/nf_nat.h. The fields are in +// network byte order. +type NfNATIPV4Range struct { + Flags uint32 + MinIP [4]byte + MaxIP [4]byte + MinPort uint16 + MaxPort uint16 +} + +// NfNATIPV4MultiRangeCompat corresponds to struct +// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h. +type NfNATIPV4MultiRangeCompat struct { + RangeSize uint32 + RangeIPV4 NfNATIPV4Range +} + +// XTRedirectTarget triggers a redirect when reached. +// Adding 4 bytes of padding to make the struct 8 byte aligned. +type XTRedirectTarget struct { + Target XTEntryTarget + NfRange NfNATIPV4MultiRangeCompat + _ [4]byte +} + +// SizeOfXTRedirectTarget is the size of an XTRedirectTarget. +const SizeOfXTRedirectTarget = 56 + // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. type IPTGetinfo struct { diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD index a77a3beea..dcd086298 100644 --- a/pkg/buffer/BUILD +++ b/pkg/buffer/BUILD @@ -10,8 +10,8 @@ go_template_instance( prefix = "buffer", template = "//pkg/ilist:generic_list", types = { - "Element": "*Buffer", - "Linker": "*Buffer", + "Element": "*buffer", + "Linker": "*buffer", }, ) @@ -34,6 +34,10 @@ go_library( go_test( name = "buffer_test", size = "small", - srcs = ["view_test.go"], + srcs = [ + "safemem_test.go", + "view_test.go", + ], library = ":buffer", + deps = ["//pkg/safemem"], ) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index d5f64609b..c6d089fd9 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -13,6 +13,10 @@ // limitations under the License. // Package buffer provides the implementation of a buffer view. +// +// A view is an flexible buffer, backed by a pool, supporting the safecopy +// operations natively as well as the ability to grow via either prepend or +// append, as well as shrink. package buffer import ( @@ -21,7 +25,7 @@ import ( const bufferSize = 8144 // See below. -// Buffer encapsulates a queueable byte buffer. +// buffer encapsulates a queueable byte buffer. // // Note that the total size is slightly less than two pages. This is done // intentionally to ensure that the buffer object aligns with runtime @@ -30,38 +34,61 @@ const bufferSize = 8144 // See below. // large enough chunk to limit excessive segmentation. // // +stateify savable -type Buffer struct { +type buffer struct { data [bufferSize]byte read int write int bufferEntry } -// Reset resets internal data. +// reset resets internal data. // -// This must be called before use. -func (b *Buffer) Reset() { +// This must be called before returning the buffer to the pool. +func (b *buffer) Reset() { b.read = 0 b.write = 0 } -// Empty indicates the buffer is empty. -// -// This indicates there is no data left to read. -func (b *Buffer) Empty() bool { - return b.read == b.write -} - // Full indicates the buffer is full. // // This indicates there is no capacity left to write. -func (b *Buffer) Full() bool { +func (b *buffer) Full() bool { return b.write == len(b.data) } +// ReadSize returns the number of bytes available for reading. +func (b *buffer) ReadSize() int { + return b.write - b.read +} + +// ReadMove advances the read index by the given amount. +func (b *buffer) ReadMove(n int) { + b.read += n +} + +// ReadSlice returns the read slice for this buffer. +func (b *buffer) ReadSlice() []byte { + return b.data[b.read:b.write] +} + +// WriteSize returns the number of bytes available for writing. +func (b *buffer) WriteSize() int { + return len(b.data) - b.write +} + +// WriteMove advances the write index by the given amount. +func (b *buffer) WriteMove(n int) { + b.write += n +} + +// WriteSlice returns the write slice for this buffer. +func (b *buffer) WriteSlice() []byte { + return b.data[b.write:] +} + // bufferPool is a pool for buffers. var bufferPool = sync.Pool{ New: func() interface{} { - return new(Buffer) + return new(buffer) }, } diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go index 071aaa488..0e5b86344 100644 --- a/pkg/buffer/safemem.go +++ b/pkg/buffer/safemem.go @@ -15,19 +15,17 @@ package buffer import ( - "io" - "gvisor.dev/gvisor/pkg/safemem" ) // WriteBlock returns this buffer as a write Block. -func (b *Buffer) WriteBlock() safemem.Block { - return safemem.BlockFromSafeSlice(b.data[b.write:]) +func (b *buffer) WriteBlock() safemem.Block { + return safemem.BlockFromSafeSlice(b.WriteSlice()) } // ReadBlock returns this buffer as a read Block. -func (b *Buffer) ReadBlock() safemem.Block { - return safemem.BlockFromSafeSlice(b.data[b.read:b.write]) +func (b *buffer) ReadBlock() safemem.Block { + return safemem.BlockFromSafeSlice(b.ReadSlice()) } // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. @@ -47,21 +45,21 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { // Need at least one buffer. firstBuf := v.data.Back() if firstBuf == nil { - firstBuf = bufferPool.Get().(*Buffer) + firstBuf = bufferPool.Get().(*buffer) v.data.PushBack(firstBuf) } // Does the last block have sufficient capacity alone? - if l := len(firstBuf.data) - firstBuf.write; l >= need { + if l := firstBuf.WriteSize(); l >= need { dst = safemem.BlockSeqOf(firstBuf.WriteBlock()) } else { // Append blocks until sufficient. need -= l blocks = append(blocks, firstBuf.WriteBlock()) for need > 0 { - emptyBuf := bufferPool.Get().(*Buffer) + emptyBuf := bufferPool.Get().(*buffer) v.data.PushBack(emptyBuf) - need -= len(emptyBuf.data) // Full block. + need -= emptyBuf.WriteSize() blocks = append(blocks, emptyBuf.WriteBlock()) } dst = safemem.BlockSeqFromSlice(blocks) @@ -73,11 +71,11 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { // Update all indices. for left := int(n); left > 0; firstBuf = firstBuf.Next() { - if l := len(firstBuf.data) - firstBuf.write; left >= l { - firstBuf.write += l // Whole block. + if l := firstBuf.WriteSize(); left >= l { + firstBuf.WriteMove(l) // Whole block. left -= l } else { - firstBuf.write += left // Partial block. + firstBuf.WriteMove(left) // Partial block. left = 0 } } @@ -103,18 +101,18 @@ func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { firstBuf := v.data.Front() if firstBuf == nil { - return 0, io.EOF + return 0, nil // No EOF. } // Is all the data in a single block? - if l := firstBuf.write - firstBuf.read; l >= need { + if l := firstBuf.ReadSize(); l >= need { src = safemem.BlockSeqOf(firstBuf.ReadBlock()) } else { // Build a list of all the buffers. need -= l blocks = append(blocks, firstBuf.ReadBlock()) for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() { - need -= buf.write - buf.read + need -= buf.ReadSize() blocks = append(blocks, buf.ReadBlock()) } src = safemem.BlockSeqFromSlice(blocks) diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go new file mode 100644 index 000000000..47f357e0c --- /dev/null +++ b/pkg/buffer/safemem_test.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "bytes" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/safemem" +) + +func TestSafemem(t *testing.T) { + testCases := []struct { + name string + input string + output string + readLen int + op func(*View) + }{ + // Basic coverage. + { + name: "short", + input: "010", + output: "010", + }, + { + name: "long", + input: "0" + strings.Repeat("1", bufferSize) + "0", + output: "0" + strings.Repeat("1", bufferSize) + "0", + }, + { + name: "short-read", + input: "0", + readLen: 100, // > size. + output: "0", + }, + { + name: "zero-read", + input: "0", + output: "", + }, + { + name: "read-empty", + input: "", + readLen: 1, // > size. + output: "", + }, + + // Ensure offsets work. + { + name: "offsets-short", + input: "012", + output: "2", + op: func(v *View) { + v.TrimFront(2) + }, + }, + { + name: "offsets-long0", + input: "0" + strings.Repeat("1", bufferSize) + "0", + output: strings.Repeat("1", bufferSize) + "0", + op: func(v *View) { + v.TrimFront(1) + }, + }, + { + name: "offsets-long1", + input: "0" + strings.Repeat("1", bufferSize) + "0", + output: strings.Repeat("1", bufferSize-1) + "0", + op: func(v *View) { + v.TrimFront(2) + }, + }, + { + name: "offsets-long2", + input: "0" + strings.Repeat("1", bufferSize) + "0", + output: "10", + op: func(v *View) { + v.TrimFront(bufferSize) + }, + }, + + // Ensure truncation works. + { + name: "truncate-short", + input: "012", + output: "01", + op: func(v *View) { + v.Truncate(2) + }, + }, + { + name: "truncate-long0", + input: "0" + strings.Repeat("1", bufferSize) + "0", + output: "0" + strings.Repeat("1", bufferSize), + op: func(v *View) { + v.Truncate(bufferSize + 1) + }, + }, + { + name: "truncate-long1", + input: "0" + strings.Repeat("1", bufferSize) + "0", + output: "0" + strings.Repeat("1", bufferSize-1), + op: func(v *View) { + v.Truncate(bufferSize) + }, + }, + { + name: "truncate-long2", + input: "0" + strings.Repeat("1", bufferSize) + "0", + output: "01", + op: func(v *View) { + v.Truncate(2) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Construct the new view. + var view View + bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input))) + n, err := view.WriteFromBlocks(bs) + if err != nil { + t.Errorf("expected err nil, got %v", err) + } + if n != uint64(len(tc.input)) { + t.Errorf("expected %d bytes, got %d", len(tc.input), n) + } + + // Run the operation. + if tc.op != nil { + tc.op(&view) + } + + // Read and validate. + readLen := tc.readLen + if readLen == 0 { + readLen = len(tc.output) // Default. + } + out := make([]byte, readLen) + bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out)) + n, err = view.ReadToBlocks(bs) + if err != nil { + t.Errorf("expected nil, got %v", err) + } + if n != uint64(len(tc.output)) { + t.Errorf("expected %d bytes, got %d", len(tc.output), n) + } + + // Ensure the contents are correct. + if !bytes.Equal(out[:n], []byte(tc.output[:n])) { + t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out)) + } + }) + } +} diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index 00fc11e9c..e6901eadb 100644 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go @@ -38,14 +38,6 @@ func (v *View) TrimFront(count int64) { } } -// Read implements io.Reader.Read. -// -// Note that reading does not advance the read index. This must be done -// manually using TrimFront or other methods. -func (v *View) Read(p []byte) (int, error) { - return v.ReadAt(p, 0) -} - // ReadAt implements io.ReaderAt.ReadAt. func (v *View) ReadAt(p []byte, offset int64) (int, error) { var ( @@ -54,54 +46,46 @@ func (v *View) ReadAt(p []byte, offset int64) (int, error) { ) for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() { needToSkip := int(offset - skipped) - if l := buf.write - buf.read; l <= needToSkip { - skipped += int64(l) + if sz := buf.ReadSize(); sz <= needToSkip { + skipped += int64(sz) continue } // Actually read data. - n := copy(p[done:], buf.data[buf.read+needToSkip:buf.write]) + n := copy(p[done:], buf.ReadSlice()[needToSkip:]) skipped += int64(needToSkip) done += int64(n) } - if int(done) < len(p) { + if int(done) < len(p) || offset+done == v.size { return int(done), io.EOF } return int(done), nil } -// Write implements io.Writer.Write. -func (v *View) Write(p []byte) (int, error) { - v.Append(p) // Does not fail. - return len(p), nil -} - // advanceRead advances the view's read index. // // Precondition: there must be sufficient bytes in the buffer. func (v *View) advanceRead(count int64) { for buf := v.data.Front(); buf != nil && count > 0; { - l := int64(buf.write - buf.read) - if l > count { + sz := int64(buf.ReadSize()) + if sz > count { // There is still data for reading. - buf.read += int(count) + buf.ReadMove(int(count)) v.size -= count count = 0 break } - // Read from this buffer. - buf.read += int(l) - count -= l - v.size -= l - - // When all data has been read from a buffer, we push - // it into the empty buffer pool for reuse. + // Consume the whole buffer. oldBuf := buf buf = buf.Next() // Iterate. v.data.Remove(oldBuf) oldBuf.Reset() bufferPool.Put(oldBuf) + + // Update counts. + count -= sz + v.size -= sz } if count > 0 { panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count)) @@ -109,37 +93,39 @@ func (v *View) advanceRead(count int64) { } // Truncate truncates the view to the given bytes. +// +// This will not grow the view, only shrink it. If a length is passed that is +// greater than the current size of the view, then nothing will happen. +// +// Precondition: length must be >= 0. func (v *View) Truncate(length int64) { - if length < 0 || length >= v.size { + if length < 0 { + panic("negative length provided") + } + if length >= v.size { return // Nothing to do. } for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() { - l := int64(buf.write - buf.read) // Local bytes. - switch { - case v.size-l >= length: - // Drop the buffer completely; see above. - v.data.Remove(buf) - v.size -= l - buf.Reset() - bufferPool.Put(buf) - - case v.size > length && v.size-l < length: - // Just truncate the buffer locally. - delta := (length - (v.size - l)) - buf.write = buf.read + int(delta) + sz := int64(buf.ReadSize()) + if after := v.size - sz; after < length { + // Truncate the buffer locally. + left := (length - after) + buf.write = buf.read + int(left) v.size = length - - default: - // Should never happen. - panic("invalid buffer during truncation") + break } + + // Drop the buffer completely; see above. + v.data.Remove(buf) + buf.Reset() + bufferPool.Put(buf) + v.size -= sz } - v.size = length // Save the new size. } -// Grow grows the given view to the number of bytes. If zero -// is true, all these bytes will be zero. If zero is false, -// then this is the caller's responsibility. +// Grow grows the given view to the number of bytes, which will be appended. If +// zero is true, all these bytes will be zero. If zero is false, then this is +// the caller's responsibility. // // Precondition: length must be >= 0. func (v *View) Grow(length int64, zero bool) { @@ -149,29 +135,29 @@ func (v *View) Grow(length int64, zero bool) { for v.size < length { buf := v.data.Back() - // Is there at least one buffer? + // Is there some space in the last buffer? if buf == nil || buf.Full() { - buf = bufferPool.Get().(*Buffer) + buf = bufferPool.Get().(*buffer) v.data.PushBack(buf) } // Write up to length bytes. - l := len(buf.data) - buf.write - if int64(l) > length-v.size { - l = int(length - v.size) + sz := buf.WriteSize() + if int64(sz) > length-v.size { + sz = int(length - v.size) } // Zero the written section; note that this pattern is // specifically recognized and optimized by the compiler. if zero { - for i := buf.write; i < buf.write+l; i++ { + for i := buf.write; i < buf.write+sz; i++ { buf.data[i] = 0 } } // Advance the index. - buf.write += l - v.size += int64(l) + buf.WriteMove(sz) + v.size += int64(sz) } } @@ -181,31 +167,40 @@ func (v *View) Prepend(data []byte) { if buf := v.data.Front(); buf != nil && buf.read > 0 { // Fill up before the first write. avail := buf.read - copy(buf.data[0:], data[len(data)-avail:]) - data = data[:len(data)-avail] - v.size += int64(avail) + bStart := 0 + dStart := len(data) - avail + if avail > len(data) { + bStart = avail - len(data) + dStart = 0 + } + n := copy(buf.data[bStart:], data[dStart:]) + data = data[:dStart] + v.size += int64(n) + buf.read -= n } for len(data) > 0 { // Do we need an empty buffer? - buf := bufferPool.Get().(*Buffer) + buf := bufferPool.Get().(*buffer) v.data.PushFront(buf) // The buffer is empty; copy last chunk. - start := len(data) - len(buf.data) - if start < 0 { - start = 0 // Everything. + avail := len(buf.data) + bStart := 0 + dStart := len(data) - avail + if avail > len(data) { + bStart = avail - len(data) + dStart = 0 } // We have to put the data at the end of the current // buffer in order to ensure that the next prepend will // correctly fill up the beginning of this buffer. - bStart := len(buf.data) - len(data[start:]) - n := copy(buf.data[bStart:], data[start:]) - buf.read = bStart - buf.write = len(buf.data) - data = data[:start] + n := copy(buf.data[bStart:], data[dStart:]) + data = data[:dStart] v.size += int64(n) + buf.read = len(buf.data) - n + buf.write = len(buf.data) } } @@ -214,16 +209,16 @@ func (v *View) Append(data []byte) { for done := 0; done < len(data); { buf := v.data.Back() - // Find the first empty buffer. + // Ensure there's a buffer with space. if buf == nil || buf.Full() { - buf = bufferPool.Get().(*Buffer) + buf = bufferPool.Get().(*buffer) v.data.PushBack(buf) } // Copy in to the given buffer. - n := copy(buf.data[buf.write:], data[done:]) + n := copy(buf.WriteSlice(), data[done:]) done += n - buf.write += n + buf.WriteMove(n) v.size += int64(n) } } @@ -232,52 +227,52 @@ func (v *View) Append(data []byte) { // // This method should not be used in any performance-sensitive paths. It may // allocate a fresh byte slice sufficiently large to contain all the data in -// the buffer. +// the buffer. This is principally for debugging. // // N.B. Tee data still belongs to this view, as if there is a single buffer // present, then it will be returned directly. This should be used for // temporary use only, and a reference to the given slice should not be held. func (v *View) Flatten() []byte { - if buf := v.data.Front(); buf.Next() == nil { - return buf.data[buf.read:buf.write] // Only one buffer. + if buf := v.data.Front(); buf == nil { + return nil // No data at all. + } else if buf.Next() == nil { + return buf.ReadSlice() // Only one buffer. } data := make([]byte, 0, v.size) // Need to flatten. for buf := v.data.Front(); buf != nil; buf = buf.Next() { // Copy to the allocated slice. - data = append(data, buf.data[buf.read:buf.write]...) + data = append(data, buf.ReadSlice()...) } return data } // Size indicates the total amount of data available in this view. -func (v *View) Size() (sz int64) { - sz = v.size // Pre-calculated. - return sz +func (v *View) Size() int64 { + return v.size } // Copy makes a strict copy of this view. func (v *View) Copy() (other View) { for buf := v.data.Front(); buf != nil; buf = buf.Next() { - other.Append(buf.data[buf.read:buf.write]) + other.Append(buf.ReadSlice()) } - return other + return } // Apply applies the given function across all valid data. func (v *View) Apply(fn func([]byte)) { for buf := v.data.Front(); buf != nil; buf = buf.Next() { - if l := int64(buf.write - buf.read); l > 0 { - fn(buf.data[buf.read:buf.write]) - } + fn(buf.ReadSlice()) } } // Merge merges the provided View with this one. // -// The other view will be empty after this operation. +// The other view will be appended to v, and other will be empty after this +// operation completes. func (v *View) Merge(other *View) { // Copy over all buffers. - for buf := other.data.Front(); buf != nil && !buf.Empty(); buf = other.data.Front() { + for buf := other.data.Front(); buf != nil; buf = other.data.Front() { other.data.Remove(buf) v.data.PushBack(buf) } @@ -288,6 +283,9 @@ func (v *View) Merge(other *View) { } // WriteFromReader writes to the buffer from an io.Reader. +// +// A minimum read size equal to unsafe.Sizeof(unintptr) is enforced, +// provided that count is greater than or equal to unsafe.Sizeof(uintptr). func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { var ( done int64 @@ -297,17 +295,17 @@ func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { for done < count { buf := v.data.Back() - // Find the first empty buffer. + // Ensure we have an empty buffer. if buf == nil || buf.Full() { - buf = bufferPool.Get().(*Buffer) + buf = bufferPool.Get().(*buffer) v.data.PushBack(buf) } // Is this less than the minimum batch? - if len(buf.data[buf.write:]) < minBatch && (count-done) >= int64(minBatch) { + if buf.WriteSize() < minBatch && (count-done) >= int64(minBatch) { tmp := make([]byte, minBatch) n, err = r.Read(tmp) - v.Write(tmp[:n]) + v.Append(tmp[:n]) done += int64(n) if err != nil { break @@ -316,14 +314,14 @@ func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { } // Limit the read, if necessary. - end := len(buf.data) - if int64(end-buf.write) > (count - done) { - end = buf.write + int(count-done) + sz := buf.WriteSize() + if left := count - done; int64(sz) > left { + sz = int(left) } // Pass the relevant portion of the buffer. - n, err = r.Read(buf.data[buf.write:end]) - buf.write += n + n, err = r.Read(buf.WriteSlice()[:sz]) + buf.WriteMove(n) done += int64(n) v.size += int64(n) if err == io.EOF { @@ -340,6 +338,9 @@ func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { // // N.B. This does not consume the bytes read. TrimFront should // be called appropriately after this call in order to do so. +// +// A minimum write size equal to unsafe.Sizeof(unintptr) is enforced, +// provided that count is greater than or equal to unsafe.Sizeof(uintptr). func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { var ( done int64 @@ -348,15 +349,22 @@ func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { ) offset := 0 // Spill-over for batching. for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() { - l := buf.write - buf.read - offset + // Has this been consumed? Skip it. + sz := buf.ReadSize() + if sz <= offset { + offset -= sz + continue + } + sz -= offset // Is this less than the minimum batch? - if l < minBatch && (count-done) >= int64(minBatch) && (v.size-done) >= int64(minBatch) { + left := count - done + if sz < minBatch && left >= int64(minBatch) && (v.size-done) >= int64(minBatch) { tmp := make([]byte, minBatch) n, err = v.ReadAt(tmp, done) w.Write(tmp[:n]) done += int64(n) - offset = n - l // Reset below. + offset = n - sz // Reset below. if err != nil { break } @@ -364,12 +372,12 @@ func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { } // Limit the write if necessary. - if int64(l) >= (count - done) { - l = int(count - done) + if int64(sz) >= left { + sz = int(left) } // Perform the actual write. - n, err = w.Write(buf.data[buf.read+offset : buf.read+offset+l]) + n, err = w.Write(buf.ReadSlice()[offset : offset+sz]) done += int64(n) if err != nil { break diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 37e652f16..3db1bc6ee 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -16,218 +16,452 @@ package buffer import ( "bytes" + "io" "strings" "testing" ) +func fillAppend(v *View, data []byte) { + v.Append(data) +} + +func fillAppendEnd(v *View, data []byte) { + v.Grow(bufferSize-1, false) + v.Append(data) + v.TrimFront(bufferSize - 1) +} + +func fillWriteFromReader(v *View, data []byte) { + b := bytes.NewBuffer(data) + v.WriteFromReader(b, int64(len(data))) +} + +func fillWriteFromReaderEnd(v *View, data []byte) { + v.Grow(bufferSize-1, false) + b := bytes.NewBuffer(data) + v.WriteFromReader(b, int64(len(data))) + v.TrimFront(bufferSize - 1) +} + +var fillFuncs = map[string]func(*View, []byte){ + "append": fillAppend, + "appendEnd": fillAppendEnd, + "writeFromReader": fillWriteFromReader, + "writeFromReaderEnd": fillWriteFromReaderEnd, +} + +func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) { + t.Helper() + d := make([]byte, n) + n, err := v.ReadAt(d, offset) + if n != len(wantStr) { + t.Errorf("got %d, want %d", n, len(wantStr)) + } + if err != wantErr { + t.Errorf("got err %v, want %v", err, wantErr) + } + if !bytes.Equal(d[:n], []byte(wantStr)) { + t.Errorf("got %q, want %q", string(d[:n]), wantStr) + } +} + func TestView(t *testing.T) { testCases := []struct { name string input string output string - ops []func(*View) + op func(*testing.T, *View) }{ - // Prepend. + // Preconditions. + { + name: "truncate-check", + input: "hello", + output: "hello", // Not touched. + op: func(t *testing.T, v *View) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Truncate(-1) did not panic") + } + }() + v.Truncate(-1) + }, + }, + { + name: "grow-check", + input: "hello", + output: "hello", // Not touched. + op: func(t *testing.T, v *View) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Grow(-1) did not panic") + } + }() + v.Grow(-1, false) + }, + }, { - name: "prepend", - input: "world", - ops: []func(*View){ - func(v *View) { - v.Prepend([]byte("hello ")) - }, + name: "advance-check", + input: "hello", + output: "", // Consumed. + op: func(t *testing.T, v *View) { + defer func() { + if r := recover(); r == nil { + t.Errorf("advanceRead(Size()+1) did not panic") + } + }() + v.advanceRead(v.Size() + 1) }, + }, + + // Prepend. + { + name: "prepend", + input: "world", output: "hello world", + op: func(t *testing.T, v *View) { + v.Prepend([]byte("hello ")) + }, }, { - name: "prepend fill", - input: strings.Repeat("1", bufferSize-1), - ops: []func(*View){ - func(v *View) { - v.Prepend([]byte("0")) - }, + name: "prepend-backfill-full", + input: "hello world", + output: "jello world", + op: func(t *testing.T, v *View) { + v.TrimFront(1) + v.Prepend([]byte("j")) }, - output: "0" + strings.Repeat("1", bufferSize-1), }, { - name: "prepend overflow", - input: strings.Repeat("1", bufferSize), - ops: []func(*View){ - func(v *View) { - v.Prepend([]byte("0")) - }, + name: "prepend-backfill-under", + input: "hello world", + output: "hola world", + op: func(t *testing.T, v *View) { + v.TrimFront(5) + v.Prepend([]byte("hola")) }, - output: "0" + strings.Repeat("1", bufferSize), }, { - name: "prepend multiple buffers", - input: strings.Repeat("1", bufferSize-1), - ops: []func(*View){ - func(v *View) { - v.Prepend([]byte(strings.Repeat("0", bufferSize*3))) - }, + name: "prepend-backfill-over", + input: "hello world", + output: "smello world", + op: func(t *testing.T, v *View) { + v.TrimFront(1) + v.Prepend([]byte("sm")) }, + }, + { + name: "prepend-fill", + input: strings.Repeat("1", bufferSize-1), + output: "0" + strings.Repeat("1", bufferSize-1), + op: func(t *testing.T, v *View) { + v.Prepend([]byte("0")) + }, + }, + { + name: "prepend-overflow", + input: strings.Repeat("1", bufferSize), + output: "0" + strings.Repeat("1", bufferSize), + op: func(t *testing.T, v *View) { + v.Prepend([]byte("0")) + }, + }, + { + name: "prepend-multiple-buffers", + input: strings.Repeat("1", bufferSize-1), output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1), + op: func(t *testing.T, v *View) { + v.Prepend([]byte(strings.Repeat("0", bufferSize*3))) + }, }, - // Append. + // Append and write. { - name: "append", - input: "hello", - ops: []func(*View){ - func(v *View) { - v.Append([]byte(" world")) - }, - }, + name: "append", + input: "hello", output: "hello world", + op: func(t *testing.T, v *View) { + v.Append([]byte(" world")) + }, }, { - name: "append fill", - input: strings.Repeat("1", bufferSize-1), - ops: []func(*View){ - func(v *View) { - v.Append([]byte("0")) - }, - }, + name: "append-fill", + input: strings.Repeat("1", bufferSize-1), output: strings.Repeat("1", bufferSize-1) + "0", + op: func(t *testing.T, v *View) { + v.Append([]byte("0")) + }, }, { - name: "append overflow", - input: strings.Repeat("1", bufferSize), - ops: []func(*View){ - func(v *View) { - v.Append([]byte("0")) - }, - }, + name: "append-overflow", + input: strings.Repeat("1", bufferSize), output: strings.Repeat("1", bufferSize) + "0", + op: func(t *testing.T, v *View) { + v.Append([]byte("0")) + }, }, { - name: "append multiple buffers", - input: strings.Repeat("1", bufferSize-1), - ops: []func(*View){ - func(v *View) { - v.Append([]byte(strings.Repeat("0", bufferSize*3))) - }, - }, + name: "append-multiple-buffers", + input: strings.Repeat("1", bufferSize-1), output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3), + op: func(t *testing.T, v *View) { + v.Append([]byte(strings.Repeat("0", bufferSize*3))) + }, }, // Truncate. { - name: "truncate", - input: "hello world", - ops: []func(*View){ - func(v *View) { - v.Truncate(5) - }, - }, + name: "truncate", + input: "hello world", output: "hello", + op: func(t *testing.T, v *View) { + v.Truncate(5) + }, }, { - name: "truncate multiple buffers", - input: strings.Repeat("1", bufferSize*2), - ops: []func(*View){ - func(v *View) { - v.Truncate(bufferSize*2 - 1) - }, + name: "truncate-noop", + input: "hello world", + output: "hello world", + op: func(t *testing.T, v *View) { + v.Truncate(v.Size() + 1) }, - output: strings.Repeat("1", bufferSize*2-1), }, { - name: "truncate multiple buffers to one buffer", - input: strings.Repeat("1", bufferSize*2), - ops: []func(*View){ - func(v *View) { - v.Truncate(5) - }, + name: "truncate-multiple-buffers", + input: strings.Repeat("1", bufferSize*2), + output: strings.Repeat("1", bufferSize*2-1), + op: func(t *testing.T, v *View) { + v.Truncate(bufferSize*2 - 1) }, + }, + { + name: "truncate-multiple-buffers-to-one", + input: strings.Repeat("1", bufferSize*2), output: "11111", + op: func(t *testing.T, v *View) { + v.Truncate(5) + }, }, // TrimFront. { - name: "trim", - input: "hello world", - ops: []func(*View){ - func(v *View) { - v.TrimFront(6) - }, - }, + name: "trim", + input: "hello world", output: "world", + op: func(t *testing.T, v *View) { + v.TrimFront(6) + }, }, { - name: "trim multiple buffers", - input: strings.Repeat("1", bufferSize*2), - ops: []func(*View){ - func(v *View) { - v.TrimFront(1) - }, + name: "trim-too-large", + input: "hello world", + output: "", + op: func(t *testing.T, v *View) { + v.TrimFront(v.Size() + 1) }, - output: strings.Repeat("1", bufferSize*2-1), }, { - name: "trim multiple buffers to one buffer", - input: strings.Repeat("1", bufferSize*2), - ops: []func(*View){ - func(v *View) { - v.TrimFront(bufferSize*2 - 1) - }, + name: "trim-multiple-buffers", + input: strings.Repeat("1", bufferSize*2), + output: strings.Repeat("1", bufferSize*2-1), + op: func(t *testing.T, v *View) { + v.TrimFront(1) }, + }, + { + name: "trim-multiple-buffers-to-one-buffer", + input: strings.Repeat("1", bufferSize*2), output: "1", + op: func(t *testing.T, v *View) { + v.TrimFront(bufferSize*2 - 1) + }, }, // Grow. { - name: "grow", - input: "hello world", - ops: []func(*View){ - func(v *View) { - v.Grow(1, true) - }, - }, + name: "grow", + input: "hello world", output: "hello world", + op: func(t *testing.T, v *View) { + v.Grow(1, true) + }, }, { - name: "grow from zero", - ops: []func(*View){ - func(v *View) { - v.Grow(1024, true) - }, - }, + name: "grow-from-zero", output: strings.Repeat("\x00", 1024), + op: func(t *testing.T, v *View) { + v.Grow(1024, true) + }, }, { - name: "grow from non-zero", - input: strings.Repeat("1", bufferSize), - ops: []func(*View){ - func(v *View) { - v.Grow(bufferSize*2, true) - }, - }, + name: "grow-from-non-zero", + input: strings.Repeat("1", bufferSize), output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize), + op: func(t *testing.T, v *View) { + v.Grow(bufferSize*2, true) + }, + }, + + // Copy. + { + name: "copy", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { + other := v.Copy() + bs := other.Flatten() + want := []byte("hello") + if !bytes.Equal(bs, want) { + t.Errorf("expected %v, got %v", want, bs) + } + }, + }, + { + name: "copy-large", + input: strings.Repeat("1", bufferSize+1), + output: strings.Repeat("1", bufferSize+1), + op: func(t *testing.T, v *View) { + other := v.Copy() + bs := other.Flatten() + want := []byte(strings.Repeat("1", bufferSize+1)) + if !bytes.Equal(bs, want) { + t.Errorf("expected %v, got %v", want, bs) + } + }, + }, + + // Merge. + { + name: "merge", + input: "hello", + output: "hello world", + op: func(t *testing.T, v *View) { + var other View + other.Append([]byte(" world")) + v.Merge(&other) + if sz := other.Size(); sz != 0 { + t.Errorf("expected 0, got %d", sz) + } + }, + }, + { + name: "merge-large", + input: strings.Repeat("1", bufferSize+1), + output: strings.Repeat("1", bufferSize+1) + strings.Repeat("0", bufferSize+1), + op: func(t *testing.T, v *View) { + var other View + other.Append([]byte(strings.Repeat("0", bufferSize+1))) + v.Merge(&other) + if sz := other.Size(); sz != 0 { + t.Errorf("expected 0, got %d", sz) + } + }, + }, + + // ReadAt. + { + name: "readat", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 6, "hello", io.EOF) }, + }, + { + name: "readat-long", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 8, "hello", io.EOF) }, + }, + { + name: "readat-short", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 3, "hel", nil) }, + }, + { + name: "readat-offset", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 3, "llo", io.EOF) }, + }, + { + name: "readat-long-offset", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 8, "llo", io.EOF) }, + }, + { + name: "readat-short-offset", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 2, "ll", nil) }, + }, + { + name: "readat-skip-all", + input: "hello", + output: "hello", + op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "", io.EOF) }, + }, + { + name: "readat-second-buffer", + input: strings.Repeat("0", bufferSize+1) + "12", + output: strings.Repeat("0", bufferSize+1) + "12", + op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "1", nil) }, + }, + { + name: "readat-second-buffer-end", + input: strings.Repeat("0", bufferSize+1) + "12", + output: strings.Repeat("0", bufferSize+1) + "12", + op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 2, "12", io.EOF) }, }, } for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Construct the new view. - var view View - view.Append([]byte(tc.input)) - - // Run all operations. - for _, op := range tc.ops { - op(&view) - } - - // Flatten and validate. - out := view.Flatten() - if !bytes.Equal([]byte(tc.output), out) { - t.Errorf("expected %q, got %q", tc.output, string(out)) - } - - // Ensure the size is correct. - if len(out) != int(view.Size()) { - t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size()) - } - }) + for fillName, fn := range fillFuncs { + t.Run(fillName+"/"+tc.name, func(t *testing.T) { + // Construct & fill the view. + var view View + fn(&view, []byte(tc.input)) + + // Run the operation. + if tc.op != nil { + tc.op(t, &view) + } + + // Flatten and validate. + out := view.Flatten() + if !bytes.Equal([]byte(tc.output), out) { + t.Errorf("expected %q, got %q", tc.output, string(out)) + } + + // Ensure the size is correct. + if len(out) != int(view.Size()) { + t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size()) + } + + // Calculate contents via apply. + var appliedOut []byte + view.Apply(func(b []byte) { + appliedOut = append(appliedOut, b...) + }) + if len(appliedOut) != len(out) { + t.Errorf("expected %d, got %d", len(out), len(appliedOut)) + } + if !bytes.Equal(appliedOut, out) { + t.Errorf("expected %v, got %v", out, appliedOut) + } + + // Calculate contents via ReadToWriter. + var b bytes.Buffer + n, err := view.ReadToWriter(&b, int64(len(out))) + if n != int64(len(out)) { + t.Errorf("expected %d, got %d", len(out), n) + } + if err != nil { + t.Errorf("expected nil, got %v", err) + } + if !bytes.Equal(b.Bytes(), out) { + t.Errorf("expected %v, got %v", out, b.Bytes()) + } + }) + } } } diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index f3a609b57..8f93e4d6d 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -169,8 +169,9 @@ func (l *List) InsertBefore(a, e Element) { // Remove removes e from l. func (l *List) Remove(e Element) { - prev := ElementMapper{}.linkerFor(e).Prev() - next := ElementMapper{}.linkerFor(e).Next() + linker := ElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() if prev != nil { ElementMapper{}.linkerFor(prev).SetNext(next) @@ -183,6 +184,9 @@ func (l *List) Remove(e Element) { } else { l.tail = prev } + + linker.SetNext(nil) + linker.SetPrev(nil) } // Entry is a default implementation of Linker. Users can add anonymous fields diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go index eba4bb535..de34005e9 100644 --- a/pkg/safemem/seq_test.go +++ b/pkg/safemem/seq_test.go @@ -20,6 +20,27 @@ import ( "testing" ) +func TestBlockSeqOfEmptyBlock(t *testing.T) { + bs := BlockSeqOf(Block{}) + if !bs.IsEmpty() { + t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs) + } +} + +func TestBlockSeqOfNonemptyBlock(t *testing.T) { + b := BlockFromSafeSlice(make([]byte, 1)) + bs := BlockSeqOf(b) + if bs.IsEmpty() { + t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs) + } + if head := bs.Head(); head != b { + t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b) + } + if tail := bs.Tail(); !tail.IsEmpty() { + t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail) + } +} + type blockSeqTest struct { desc string diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go index dcdfc9600..f5f0574f8 100644 --- a/pkg/safemem/seq_unsafe.go +++ b/pkg/safemem/seq_unsafe.go @@ -56,6 +56,9 @@ type BlockSeq struct { // BlockSeqOf returns a BlockSeq representing the single Block b. func BlockSeqOf(b Block) BlockSeq { + if b.length == 0 { + return BlockSeq{} + } bs := BlockSeq{ data: b.start, length: -1, diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index da5a5e4b2..88766f33b 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -451,7 +451,7 @@ func TestRandom(t *testing.T) { } } - fmt.Printf("Testing filters: %v", syscallRules) + t.Logf("Testing filters: %v", syscallRules) instrs, err := BuildProgram([]RuleSet{ RuleSet{ Rules: syscallRules, diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 01940bca4..c29e1b841 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -95,6 +95,9 @@ type State struct { // Our floating point state. aarch64FPState `state:"wait"` + // TLS pointer + TPValue uint64 + // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet @@ -148,6 +151,7 @@ func (s *State) Fork() State { return State{ Regs: s.Regs, aarch64FPState: s.aarch64FPState.fork(), + TPValue: s.TPValue, FeatureSet: s.FeatureSet, OrigR0: s.OrigR0, } @@ -259,6 +263,7 @@ func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) { const ( _NT_PRSTATUS = 1 _NT_PRFPREG = 2 + _NT_ARM_TLS = 0x401 ) // PtraceGetRegSet implements Context.PtraceGetRegSet. diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 885115ae2..db99c5acb 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -140,16 +140,17 @@ func (c *context64) SetStack(value uintptr) { // TLS returns the current TLS pointer. func (c *context64) TLS() uintptr { - // TODO(gvisor.dev/issue/1238): TLS is not supported. - // MRS_TPIDR_EL0 - return 0 + return uintptr(c.TPValue) } // SetTLS sets the current TLS pointer. Returns false if value is invalid. func (c *context64) SetTLS(value uintptr) bool { - // TODO(gvisor.dev/issue/1238): TLS is not supported. - // MSR_TPIDR_EL0 - return false + if value >= uintptr(maxAddr64) { + return false + } + + c.TPValue = uint64(value) + return true } // SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP. diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 151808911..663e51989 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -117,9 +117,9 @@ func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error { return nil } -// Goroutine is an RPC stub which dumps out the stack trace for all running -// goroutines. -func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error { +// GoroutineProfile is an RPC stub which dumps out the stack trace for all +// running goroutines. +func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { return errNoOutput } @@ -131,6 +131,34 @@ func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error { return nil } +// BlockProfile is an RPC stub which dumps out the stack trace that led to +// blocking on synchronization primitives. +func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error { + if len(o.FilePayload.Files) < 1 { + return errNoOutput + } + output := o.FilePayload.Files[0] + defer output.Close() + if err := pprof.Lookup("block").WriteTo(output, 0); err != nil { + return err + } + return nil +} + +// MutexProfile is an RPC stub which dumps out the stack trace of holders of +// contended mutexes. +func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error { + if len(o.FilePayload.Files) < 1 { + return errNoOutput + } + output := o.FilePayload.Files[0] + defer output.Close() + if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil { + return err + } + return nil +} + // StartTrace is an RPC stub which starts collection of an execution trace. func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 25514ace4..33de32c69 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -101,8 +101,6 @@ func (c *DirentCache) remove(d *Dirent) { panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d)) } c.list.Remove(d) - d.SetPrev(nil) - d.SetNext(nil) d.DecRef() c.currentSize-- if c.limit != nil { diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 21003ea45..011625c80 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -10,7 +10,7 @@ go_library( "descriptor_state.go", "device.go", "file.go", - "fs.go", + "host.go", "inode.go", "inode_state.go", "ioctl_unsafe.go", @@ -62,14 +62,12 @@ go_test( size = "small", srcs = [ "descriptor_test.go", - "fs_test.go", "inode_test.go", "socket_test.go", "wait_test.go", ], library = ":host", deps = [ - "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", "//pkg/sentry/contexttest", diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go index 1658979fc..cd84e1337 100644 --- a/pkg/sentry/fs/host/control.go +++ b/pkg/sentry/fs/host/control.go @@ -32,6 +32,8 @@ func newSCMRights(fds []int) control.SCMRights { } // Files implements control.SCMRights.Files. +// +// TODO(gvisor.dev/issue/2017): Port to VFS2. func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFiles, bool) { n := max var trunc bool diff --git a/pkg/sentry/fs/host/descriptor.go b/pkg/sentry/fs/host/descriptor.go index 2a4d1b291..cfdce6a74 100644 --- a/pkg/sentry/fs/host/descriptor.go +++ b/pkg/sentry/fs/host/descriptor.go @@ -16,7 +16,6 @@ package host import ( "fmt" - "path" "syscall" "gvisor.dev/gvisor/pkg/fdnotifier" @@ -28,12 +27,9 @@ import ( // // +stateify savable type descriptor struct { - // donated is true if the host fd was donated by another process. - donated bool - // If origFD >= 0, it is the host fd that this file was originally created // from, which must be available at time of restore. The FD can be closed - // after descriptor is created. Only set if donated is true. + // after descriptor is created. origFD int // wouldBlock is true if value (below) points to a file that can @@ -41,15 +37,13 @@ type descriptor struct { wouldBlock bool // value is the wrapped host fd. It is never saved or restored - // directly. How it is restored depends on whether it was - // donated and the fs.MountSource it was originally - // opened/created from. + // directly. value int `state:"nosave"` } // newDescriptor returns a wrapped host file descriptor. On success, // the descriptor is registered for event notifications with queue. -func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) { +func newDescriptor(fd int, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) { ownedFD := fd origFD := -1 if saveable { @@ -69,7 +63,6 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue * } } return &descriptor{ - donated: donated, origFD: origFD, wouldBlock: wouldBlock, value: ownedFD, @@ -77,25 +70,11 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue * } // initAfterLoad initializes the value of the descriptor after Load. -func (d *descriptor) initAfterLoad(mo *superOperations, id uint64, queue *waiter.Queue) error { - if d.donated { - var err error - d.value, err = syscall.Dup(d.origFD) - if err != nil { - return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err) - } - } else { - name, ok := mo.inodeMappings[id] - if !ok { - return fmt.Errorf("failed to find path for inode number %d", id) - } - fullpath := path.Join(mo.root, name) - - var err error - d.value, err = open(nil, fullpath) - if err != nil { - return fmt.Errorf("failed to open %q: %v", fullpath, err) - } +func (d *descriptor) initAfterLoad(id uint64, queue *waiter.Queue) error { + var err error + d.value, err = syscall.Dup(d.origFD) + if err != nil { + return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err) } if d.wouldBlock { if err := syscall.SetNonblock(d.value, true); err != nil { diff --git a/pkg/sentry/fs/host/descriptor_state.go b/pkg/sentry/fs/host/descriptor_state.go index 8167390a9..e880582ab 100644 --- a/pkg/sentry/fs/host/descriptor_state.go +++ b/pkg/sentry/fs/host/descriptor_state.go @@ -16,7 +16,7 @@ package host // beforeSave is invoked by stateify. func (d *descriptor) beforeSave() { - if d.donated && d.origFD < 0 { + if d.origFD < 0 { panic("donated file descriptor cannot be saved") } } diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go index 4205981f5..d8e4605b6 100644 --- a/pkg/sentry/fs/host/descriptor_test.go +++ b/pkg/sentry/fs/host/descriptor_test.go @@ -47,10 +47,10 @@ func TestDescriptorRelease(t *testing.T) { // FD ownership is transferred to the descritor. queue := &waiter.Queue{} - d, err := newDescriptor(fd, false /* donated*/, tc.saveable, tc.wouldBlock, queue) + d, err := newDescriptor(fd, tc.saveable, tc.wouldBlock, queue) if err != nil { syscall.Close(fd) - t.Fatalf("newDescriptor(%d, %t, false, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err) + t.Fatalf("newDescriptor(%d, %t, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err) } if tc.saveable { if d.origFD < 0 { diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go index e08f56d04..034862694 100644 --- a/pkg/sentry/fs/host/file.go +++ b/pkg/sentry/fs/host/file.go @@ -101,8 +101,8 @@ func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner }) return s, nil default: - msrc := newMountSource(ctx, "/", mounter, &Filesystem{}, fs.MountSourceFlags{}, false /* dontTranslateOwnership */) - inode, err := newInode(ctx, msrc, donated, saveable, true /* donated */) + msrc := fs.NewNonCachingMountSource(ctx, &filesystem{}, fs.MountSourceFlags{}) + inode, err := newInode(ctx, msrc, donated, saveable) if err != nil { return nil, err } diff --git a/pkg/sentry/fs/host/fs.go b/pkg/sentry/fs/host/fs.go deleted file mode 100644 index d3e8e3a36..000000000 --- a/pkg/sentry/fs/host/fs.go +++ /dev/null @@ -1,339 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package host implements an fs.Filesystem for files backed by host -// file descriptors. -package host - -import ( - "fmt" - "path" - "path/filepath" - "strconv" - "strings" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// FilesystemName is the name under which Filesystem is registered. -const FilesystemName = "whitelistfs" - -const ( - // whitelistKey is the mount option containing a comma-separated list - // of host paths to whitelist. - whitelistKey = "whitelist" - - // rootPathKey is the mount option containing the root path of the - // mount. - rootPathKey = "root" - - // dontTranslateOwnershipKey is the key to superOperations.dontTranslateOwnership. - dontTranslateOwnershipKey = "dont_translate_ownership" -) - -// maxTraversals determines link traversals in building the whitelist. -const maxTraversals = 10 - -// Filesystem is a pseudo file system that is only available during the setup -// to lock down the configurations. This filesystem should only be mounted at root. -// -// Think twice before exposing this to applications. -// -// +stateify savable -type Filesystem struct { - // whitelist is a set of host paths to whitelist. - paths []string -} - -var _ fs.Filesystem = (*Filesystem)(nil) - -// Name is the identifier of this file system. -func (*Filesystem) Name() string { - return FilesystemName -} - -// AllowUserMount prohibits users from using mount(2) with this file system. -func (*Filesystem) AllowUserMount() bool { - return false -} - -// AllowUserList allows this filesystem to be listed in /proc/filesystems. -func (*Filesystem) AllowUserList() bool { - return true -} - -// Flags returns that there is nothing special about this file system. -func (*Filesystem) Flags() fs.FilesystemFlags { - return 0 -} - -// Mount returns an fs.Inode exposing the host file system. It is intended to be locked -// down in PreExec below. -func (f *Filesystem) Mount(ctx context.Context, _ string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) { - // Parse generic comma-separated key=value options. - options := fs.GenericMountSourceOptions(data) - - // Grab the whitelist if one was specified. - // TODO(edahlgren/mpratt/hzy): require another option "testonly" in order to allow - // no whitelist. - if wl, ok := options[whitelistKey]; ok { - f.paths = strings.Split(wl, "|") - delete(options, whitelistKey) - } - - // If the rootPath was set, use it. Othewise default to the root of the - // host fs. - rootPath := "/" - if rp, ok := options[rootPathKey]; ok { - rootPath = rp - delete(options, rootPathKey) - - // We must relativize the whitelisted paths to the new root. - for i, p := range f.paths { - rel, err := filepath.Rel(rootPath, p) - if err != nil { - return nil, fmt.Errorf("whitelist path %q must be a child of root path %q", p, rootPath) - } - f.paths[i] = path.Join("/", rel) - } - } - fd, err := open(nil, rootPath) - if err != nil { - return nil, fmt.Errorf("failed to find root: %v", err) - } - - var dontTranslateOwnership bool - if v, ok := options[dontTranslateOwnershipKey]; ok { - b, err := strconv.ParseBool(v) - if err != nil { - return nil, fmt.Errorf("invalid value for %q: %v", dontTranslateOwnershipKey, err) - } - dontTranslateOwnership = b - delete(options, dontTranslateOwnershipKey) - } - - // Fail if the caller passed us more options than we know about. - if len(options) > 0 { - return nil, fmt.Errorf("unsupported mount options: %v", options) - } - - // The mounting EUID/EGID will be cached by this file system. This will - // be used to assign ownership to files that we own. - owner := fs.FileOwnerFromContext(ctx) - - // Construct the host file system mount and inode. - msrc := newMountSource(ctx, rootPath, owner, f, flags, dontTranslateOwnership) - return newInode(ctx, msrc, fd, false /* saveable */, false /* donated */) -} - -// InstallWhitelist locks down the MountNamespace to only the currently installed -// Dirents and the given paths. -func (f *Filesystem) InstallWhitelist(ctx context.Context, m *fs.MountNamespace) error { - return installWhitelist(ctx, m, f.paths) -} - -func installWhitelist(ctx context.Context, m *fs.MountNamespace, paths []string) error { - if len(paths) == 0 || (len(paths) == 1 && paths[0] == "") { - // Warning will be logged during filter installation if the empty - // whitelist matters (allows for host file access). - return nil - } - - // Done tracks entries already added. - done := make(map[string]bool) - root := m.Root() - defer root.DecRef() - - for i := 0; i < len(paths); i++ { - // Make sure the path is absolute. This is a sanity check. - if !path.IsAbs(paths[i]) { - return fmt.Errorf("path %q is not absolute", paths[i]) - } - - // We need to add all the intermediate paths, in case one of - // them is a symlink that needs to be resolved. - for j := 1; j <= len(paths[i]); j++ { - if j < len(paths[i]) && paths[i][j] != '/' { - continue - } - current := paths[i][:j] - - // Lookup the given component in the tree. - remainingTraversals := uint(maxTraversals) - d, err := m.FindLink(ctx, root, nil, current, &remainingTraversals) - if err != nil { - log.Warningf("populate failed for %q: %v", current, err) - continue - } - - // It's critical that this DecRef happens after the - // freeze below. This ensures that the dentry is in - // place to be frozen. Otherwise, we freeze without - // these entries. - defer d.DecRef() - - // Expand the last component if necessary. - if current == paths[i] { - // Is it a directory or symlink? - sattr := d.Inode.StableAttr - if fs.IsDir(sattr) { - for name := range childDentAttrs(ctx, d) { - paths = append(paths, path.Join(current, name)) - } - } - if fs.IsSymlink(sattr) { - // Only expand symlinks once. The - // folder structure may contain - // recursive symlinks and we don't want - // to end up infinitely expanding this - // symlink. This is safe because this - // is the last component. If a later - // path wants to symlink something - // beneath this symlink that will still - // be handled by the FindLink above. - if done[current] { - continue - } - - s, err := d.Inode.Readlink(ctx) - if err != nil { - log.Warningf("readlink failed for %q: %v", current, err) - continue - } - if path.IsAbs(s) { - paths = append(paths, s) - } else { - target := path.Join(path.Dir(current), s) - paths = append(paths, target) - } - } - } - - // Only report this one once even though we may look - // it up more than once. If we whitelist /a/b,/a then - // /a will be "done" when it is looked up for /a/b, - // however we still need to expand all of its contents - // when whitelisting /a. - if !done[current] { - log.Debugf("whitelisted: %s", current) - } - done[current] = true - } - } - - // Freeze the mount tree in place. This prevents any new paths from - // being opened and any old ones from being removed. If we do provide - // tmpfs mounts, we'll want to freeze/thaw those separately. - m.Freeze() - return nil -} - -func childDentAttrs(ctx context.Context, d *fs.Dirent) map[string]fs.DentAttr { - dirname, _ := d.FullName(nil /* root */) - dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - log.Warningf("failed to open directory %q: %v", dirname, err) - return nil - } - dir.DecRef() - var stubSerializer fs.CollectEntriesSerializer - if err := dir.Readdir(ctx, &stubSerializer); err != nil { - log.Warningf("failed to iterate on host directory %q: %v", dirname, err) - return nil - } - delete(stubSerializer.Entries, ".") - delete(stubSerializer.Entries, "..") - return stubSerializer.Entries -} - -// newMountSource constructs a new host fs.MountSource -// relative to a root path. The root should match the mount point. -func newMountSource(ctx context.Context, root string, mounter fs.FileOwner, filesystem fs.Filesystem, flags fs.MountSourceFlags, dontTranslateOwnership bool) *fs.MountSource { - return fs.NewMountSource(ctx, &superOperations{ - root: root, - inodeMappings: make(map[uint64]string), - mounter: mounter, - dontTranslateOwnership: dontTranslateOwnership, - }, filesystem, flags) -} - -// superOperations implements fs.MountSourceOperations. -// -// +stateify savable -type superOperations struct { - fs.SimpleMountSourceOperations - - // root is the path of the mount point. All inode mappings - // are relative to this root. - root string - - // inodeMappings contains mappings of fs.Inodes associated - // with this MountSource to paths under root. - inodeMappings map[uint64]string - - // mounter is the cached EUID/EGID that mounted this file system. - mounter fs.FileOwner - - // dontTranslateOwnership indicates whether to not translate file - // ownership. - // - // By default, files/directories owned by the sandbox uses UID/GID - // of the mounter. For files/directories that are not owned by the - // sandbox, file UID/GID is translated to a UID/GID which cannot - // be mapped in the sandboxed application's user namespace. The - // UID/GID will look like the nobody UID/GID (65534) but is not - // strictly owned by the user "nobody". - // - // If whitelistfs is a lower filesystem in an overlay, set - // dont_translate_ownership=true in mount options. - dontTranslateOwnership bool -} - -var _ fs.MountSourceOperations = (*superOperations)(nil) - -// ResetInodeMappings implements fs.MountSourceOperations.ResetInodeMappings. -func (m *superOperations) ResetInodeMappings() { - m.inodeMappings = make(map[uint64]string) -} - -// SaveInodeMapping implements fs.MountSourceOperations.SaveInodeMapping. -func (m *superOperations) SaveInodeMapping(inode *fs.Inode, path string) { - // This is very unintuitive. We *CANNOT* trust the inode's StableAttrs, - // because overlay copyUp may have changed them out from under us. - // So much for "immutable". - sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr - m.inodeMappings[sattr.InodeID] = path -} - -// Keep implements fs.MountSourceOperations.Keep. -// -// TODO(b/72455313,b/77596690): It is possible to change the permissions on a -// host file while it is in the dirent cache (say from RO to RW), but it is not -// possible to re-open the file with more relaxed permissions, since the host -// FD is already open and stored in the inode. -// -// Using the dirent LRU cache increases the odds that this bug is encountered. -// Since host file access is relatively fast anyways, we disable the LRU cache -// for host fs files. Once we can properly deal with permissions changes and -// re-opening host files, we should revisit whether or not to make use of the -// LRU cache. -func (*superOperations) Keep(*fs.Dirent) bool { - return false -} - -func init() { - fs.RegisterFilesystem(&Filesystem{}) -} diff --git a/pkg/sentry/fs/host/fs_test.go b/pkg/sentry/fs/host/fs_test.go deleted file mode 100644 index 3111d2df9..000000000 --- a/pkg/sentry/fs/host/fs_test.go +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "fmt" - "io/ioutil" - "os" - "path" - "reflect" - "sort" - "testing" - - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// newTestMountNamespace creates a MountNamespace with a ramfs root. -// It returns the host folder created, which should be removed when done. -func newTestMountNamespace(t *testing.T) (*fs.MountNamespace, string, error) { - p, err := ioutil.TempDir("", "root") - if err != nil { - return nil, "", err - } - - fd, err := open(nil, p) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - ctx := contexttest.Context(t) - root, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - mm, err := fs.NewMountNamespace(ctx, root) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - return mm, p, nil -} - -// createTestDirs populates the root with some test files and directories. -// /a/a1.txt -// /a/a2.txt -// /b/b1.txt -// /b/c/c1.txt -// /symlinks/normal.txt -// /symlinks/to_normal.txt -> /symlinks/normal.txt -// /symlinks/recursive -> /symlinks -func createTestDirs(ctx context.Context, t *testing.T, m *fs.MountNamespace) error { - r := m.Root() - defer r.DecRef() - - if err := r.CreateDirectory(ctx, r, "a", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - a, err := r.Walk(ctx, r, "a") - if err != nil { - return err - } - defer a.DecRef() - - a1, err := a.Create(ctx, r, "a1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - a1.DecRef() - - a2, err := a.Create(ctx, r, "a2.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - a2.DecRef() - - if err := r.CreateDirectory(ctx, r, "b", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - b, err := r.Walk(ctx, r, "b") - if err != nil { - return err - } - defer b.DecRef() - - b1, err := b.Create(ctx, r, "b1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - b1.DecRef() - - if err := b.CreateDirectory(ctx, r, "c", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - c, err := b.Walk(ctx, r, "c") - if err != nil { - return err - } - defer c.DecRef() - - c1, err := c.Create(ctx, r, "c1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - c1.DecRef() - - if err := r.CreateDirectory(ctx, r, "symlinks", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - symlinks, err := r.Walk(ctx, r, "symlinks") - if err != nil { - return err - } - defer symlinks.DecRef() - - normal, err := symlinks.Create(ctx, r, "normal.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - normal.DecRef() - - if err := symlinks.CreateLink(ctx, r, "/symlinks/normal.txt", "to_normal.txt"); err != nil { - return err - } - - return symlinks.CreateLink(ctx, r, "/symlinks", "recursive") -} - -// allPaths returns a slice of all paths of entries visible in the rootfs. -func allPaths(ctx context.Context, t *testing.T, m *fs.MountNamespace, base string) ([]string, error) { - var paths []string - root := m.Root() - defer root.DecRef() - - maxTraversals := uint(1) - d, err := m.FindLink(ctx, root, nil, base, &maxTraversals) - if err != nil { - t.Logf("FindLink failed for %q", base) - return paths, err - } - defer d.DecRef() - - if fs.IsDir(d.Inode.StableAttr) { - dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - return nil, fmt.Errorf("failed to open directory %q: %v", base, err) - } - iter, ok := dir.FileOperations.(fs.DirIterator) - if !ok { - return nil, fmt.Errorf("cannot directly iterate on host directory %q", base) - } - dirCtx := &fs.DirCtx{ - Serializer: noopDentrySerializer{}, - } - if _, err := fs.DirentReaddir(ctx, d, iter, root, dirCtx, 0); err != nil { - return nil, err - } - for name := range dirCtx.DentAttrs() { - if name == "." || name == ".." { - continue - } - - fullName := path.Join(base, name) - paths = append(paths, fullName) - - // Recurse. - subpaths, err := allPaths(ctx, t, m, fullName) - if err != nil { - return paths, err - } - paths = append(paths, subpaths...) - } - } - - return paths, nil -} - -type noopDentrySerializer struct{} - -func (noopDentrySerializer) CopyOut(string, fs.DentAttr) error { - return nil -} -func (noopDentrySerializer) Written() int { - return 4096 -} - -// pathsEqual returns true if the two string slices contain the same entries. -func pathsEqual(got, want []string) bool { - sort.Strings(got) - sort.Strings(want) - - if len(got) != len(want) { - return false - } - - for i := range got { - if got[i] != want[i] { - return false - } - } - - return true -} - -func TestWhitelist(t *testing.T) { - for _, test := range []struct { - // description of the test. - desc string - // paths are the paths to whitelist - paths []string - // want are all of the directory entries that should be - // visible (nothing beyond this set should be visible). - want []string - }{ - { - desc: "root", - paths: []string{"/"}, - want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt", "/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt", "/symlinks/recursive"}, - }, - { - desc: "top-level directories", - paths: []string{"/a", "/b"}, - want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "nested directories (1/2)", - paths: []string{"/b", "/b/c"}, - want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "nested directories (2/2)", - paths: []string{"/b/c", "/b"}, - want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "single file", - paths: []string{"/b/c/c1.txt"}, - want: []string{"/b", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "single file and directory", - paths: []string{"/a/a1.txt", "/b/c"}, - want: []string{"/a", "/a/a1.txt", "/b", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "symlink", - paths: []string{"/symlinks/to_normal.txt"}, - want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt"}, - }, - { - desc: "recursive symlink", - paths: []string{"/symlinks/recursive/normal.txt"}, - want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/recursive"}, - }, - } { - t.Run(test.desc, func(t *testing.T) { - m, p, err := newTestMountNamespace(t) - if err != nil { - t.Errorf("Failed to create MountNamespace: %v", err) - } - defer os.RemoveAll(p) - - ctx := withRoot(contexttest.RootContext(t), m.Root()) - if err := createTestDirs(ctx, t, m); err != nil { - t.Errorf("Failed to create test dirs: %v", err) - } - - if err := installWhitelist(ctx, m, test.paths); err != nil { - t.Errorf("installWhitelist(%v) err got %v want nil", test.paths, err) - } - - got, err := allPaths(ctx, t, m, "/") - if err != nil { - t.Fatalf("Failed to lookup paths (whitelisted: %v): %v", test.paths, err) - } - - if !pathsEqual(got, test.want) { - t.Errorf("For paths %v got %v want %v", test.paths, got, test.want) - } - }) - } -} - -func TestRootPath(t *testing.T) { - // Create a temp dir, which will be the root of our mounted fs. - rootPath, err := ioutil.TempDir(os.TempDir(), "root") - if err != nil { - t.Fatalf("TempDir failed: %v", err) - } - defer os.RemoveAll(rootPath) - - // Create two files inside the new root, one which will be whitelisted - // and one not. - whitelisted, err := ioutil.TempFile(rootPath, "white") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - if _, err := ioutil.TempFile(rootPath, "black"); err != nil { - t.Fatalf("TempFile failed: %v", err) - } - - // Create a mount with a root path and single whitelisted file. - hostFS := &Filesystem{} - ctx := contexttest.Context(t) - data := fmt.Sprintf("%s=%s,%s=%s", rootPathKey, rootPath, whitelistKey, whitelisted.Name()) - inode, err := hostFS.Mount(ctx, "", fs.MountSourceFlags{}, data, nil) - if err != nil { - t.Fatalf("Mount failed: %v", err) - } - mm, err := fs.NewMountNamespace(ctx, inode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - if err := hostFS.InstallWhitelist(ctx, mm); err != nil { - t.Fatalf("InstallWhitelist failed: %v", err) - } - - // Get the contents of the root directory. - rootDir := mm.Root() - rctx := withRoot(ctx, rootDir) - f, err := rootDir.Inode.GetFile(rctx, rootDir, fs.FileFlags{}) - if err != nil { - t.Fatalf("GetFile failed: %v", err) - } - c := &fs.CollectEntriesSerializer{} - if err := f.Readdir(rctx, c); err != nil { - t.Fatalf("Readdir failed: %v", err) - } - - // We should have only our whitelisted file, plus the dots. - want := []string{path.Base(whitelisted.Name()), ".", ".."} - got := c.Order - sort.Strings(want) - sort.Strings(got) - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got %v, wanted %v", got, want) - } -} - -type rootContext struct { - context.Context - root *fs.Dirent -} - -// withRoot returns a copy of ctx with the given root. -func withRoot(ctx context.Context, root *fs.Dirent) context.Context { - return &rootContext{ - Context: ctx, - root: root, - } -} - -// Value implements Context.Value. -func (rc rootContext) Value(key interface{}) interface{} { - switch key { - case fs.CtxRoot: - rc.root.IncRef() - return rc.root - default: - return rc.Context.Value(key) - } -} diff --git a/pkg/sentry/fs/host/host.go b/pkg/sentry/fs/host/host.go new file mode 100644 index 000000000..081ba1dd8 --- /dev/null +++ b/pkg/sentry/fs/host/host.go @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package host supports file descriptors imported directly. +package host + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// filesystem is a host filesystem. +// +// +stateify savable +type filesystem struct{} + +func init() { + fs.RegisterFilesystem(&filesystem{}) +} + +// FilesystemName is the name under which the filesystem is registered. +const FilesystemName = "host" + +// Name is the name of the filesystem. +func (*filesystem) Name() string { + return FilesystemName +} + +// Mount returns an error. Mounting hostfs is not allowed. +func (*filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, dataObj interface{}) (*fs.Inode, error) { + return nil, syserror.EPERM +} + +// AllowUserMount prohibits users from using mount(2) with this file system. +func (*filesystem) AllowUserMount() bool { + return false +} + +// AllowUserList prohibits this filesystem to be listed in /proc/filesystems. +func (*filesystem) AllowUserList() bool { + return false +} + +// Flags returns that there is nothing special about this file system. +func (*filesystem) Flags() fs.FilesystemFlags { + return 0 +} diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 6fa39caab..1da3c0a17 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -17,12 +17,10 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/secio" - "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -69,9 +67,6 @@ type inodeOperations struct { // // +stateify savable type inodeFileState struct { - // Common file system state. - mops *superOperations `state:"wait"` - // descriptor is the backing host FD. descriptor *descriptor `state:"wait"` @@ -160,7 +155,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err if err := syscall.Fstat(i.FD(), &s); err != nil { return fs.UnstableAttr{}, err } - return unstableAttr(i.mops, &s), nil + return unstableAttr(&s), nil } // Allocate implements fsutil.CachedFileObject.Allocate. @@ -172,7 +167,7 @@ func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error var _ fs.InodeOperations = (*inodeOperations)(nil) // newInode returns a new fs.Inode backed by the host FD. -func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, donated bool) (*fs.Inode, error) { +func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool) (*fs.Inode, error) { // Retrieve metadata. var s syscall.Stat_t err := syscall.Fstat(fd, &s) @@ -181,24 +176,17 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, } fileState := &inodeFileState{ - mops: msrc.MountSourceOperations.(*superOperations), sattr: stableAttr(&s), } // Initialize the wrapped host file descriptor. - fileState.descriptor, err = newDescriptor( - fd, - donated, - saveable, - wouldBlock(&s), - &fileState.queue, - ) + fileState.descriptor, err = newDescriptor(fd, saveable, wouldBlock(&s), &fileState.queue) if err != nil { return nil, err } // Build the fs.InodeOperations. - uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s) + uattr := unstableAttr(&s) iops := &inodeOperations{ fileState: fileState, cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{ @@ -232,54 +220,23 @@ func (i *inodeOperations) Release(context.Context) { // Lookup implements fs.InodeOperations.Lookup. func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) { - // Get a new FD relative to i at name. - fd, err := open(i, name) - if err != nil { - if err == syserror.ENOENT { - return nil, syserror.ENOENT - } - return nil, err - } - - inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */) - if err != nil { - return nil, err - } - - // Return the fs.Dirent. - return fs.NewDirent(ctx, inode, name), nil + return nil, syserror.ENOENT } // Create implements fs.InodeOperations.Create. func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) { - // Create a file relative to i at name. - // - // N.B. We always open this file O_RDWR regardless of flags because a - // future GetFile might want more access. Open allows this regardless - // of perm. - fd, err := openAt(i, name, syscall.O_RDWR|syscall.O_CREAT|syscall.O_EXCL, perm.LinuxMode()) - if err != nil { - return nil, err - } - - inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */) - if err != nil { - return nil, err - } + return nil, syserror.EPERM - d := fs.NewDirent(ctx, inode, name) - defer d.DecRef() - return inode.GetFile(ctx, d, flags) } // CreateDirectory implements fs.InodeOperations.CreateDirectory. func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error { - return syscall.Mkdirat(i.fileState.FD(), name, uint32(perm.LinuxMode())) + return syserror.EPERM } // CreateLink implements fs.InodeOperations.CreateLink. func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error { - return createLink(i.fileState.FD(), oldname, newname) + return syserror.EPERM } // CreateHardLink implements fs.InodeOperations.CreateHardLink. @@ -294,25 +251,17 @@ func (*inodeOperations) CreateFifo(context.Context, *fs.Inode, string, fs.FilePe // Remove implements fs.InodeOperations.Remove. func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error { - return unlinkAt(i.fileState.FD(), name, false /* dir */) + return syserror.EPERM } // RemoveDirectory implements fs.InodeOperations.RemoveDirectory. func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error { - return unlinkAt(i.fileState.FD(), name, true /* dir */) + return syserror.EPERM } // Rename implements fs.InodeOperations.Rename. func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error { - op, ok := oldParent.InodeOperations.(*inodeOperations) - if !ok { - return syscall.EXDEV - } - np, ok := newParent.InodeOperations.(*inodeOperations) - if !ok { - return syscall.EXDEV - } - return syscall.Renameat(op.fileState.FD(), oldName, np.fileState.FD(), newName) + return syserror.EPERM } // Bind implements fs.InodeOperations.Bind. @@ -461,69 +410,7 @@ func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {} // readdirAll returns all of the directory entries in i. func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) { - i.readdirMu.Lock() - defer i.readdirMu.Unlock() - - fd := i.fileState.FD() - - // syscall.ReadDirent will use getdents, which will seek the file past - // the last directory entry. To read the directory entries a second - // time, we need to seek back to the beginning. - if _, err := syscall.Seek(fd, 0, 0); err != nil { - if err == syscall.ESPIPE { - // All directories should be seekable. If this file - // isn't seekable, it is not a directory and we should - // return that more sane error. - err = syscall.ENOTDIR - } - return nil, err - } - - names := make([]string, 0, 100) - for { - // Refill the buffer if necessary - if d.bufp >= d.nbuf { - d.bufp = 0 - // ReadDirent will just do a sys_getdents64 to the kernel. - n, err := syscall.ReadDirent(fd, d.buf) - if err != nil { - return nil, err - } - if n == 0 { - break // EOF - } - d.nbuf = n - } - - var nb int - // Parse the dirent buffer we just get and return the directory names along - // with the number of bytes consumed in the buffer. - nb, _, names = syscall.ParseDirent(d.buf[d.bufp:d.nbuf], -1, names) - d.bufp += nb - } - - entries := make(map[string]fs.DentAttr) - for _, filename := range names { - // Lookup the type and host device and inode. - stat, lerr := fstatat(fd, filename, linux.AT_SYMLINK_NOFOLLOW) - if lerr == syscall.ENOENT { - // File disappeared between readdir and lstat. - // Just treat it as if it didn't exist. - continue - } - - // There was a serious problem, we should probably report it. - if lerr != nil { - return nil, lerr - } - - entries[filename] = fs.DentAttr{ - Type: nodeType(&stat), - InodeID: hostFileDevice.Map(device.MultiDeviceKey{ - Device: stat.Dev, - Inode: stat.Ino, - }), - } - } - return entries, nil + // We only support non-directory file descriptors that have been + // imported, so just claim that this isn't a directory, even if it is. + return nil, syscall.ENOTDIR } diff --git a/pkg/sentry/fs/host/inode_state.go b/pkg/sentry/fs/host/inode_state.go index 299e0e0b0..1adbd4562 100644 --- a/pkg/sentry/fs/host/inode_state.go +++ b/pkg/sentry/fs/host/inode_state.go @@ -18,29 +18,14 @@ import ( "fmt" "syscall" - "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" ) -// beforeSave is invoked by stateify. -func (i *inodeFileState) beforeSave() { - if !i.queue.IsEmpty() { - panic("event queue must be empty") - } - if !i.descriptor.donated && i.sattr.Type == fs.RegularFile { - uattr, err := i.unstableAttr(context.Background()) - if err != nil { - panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.mops.inodeMappings[i.sattr.InodeID], err)}) - } - i.savedUAttr = &uattr - } -} - // afterLoad is invoked by stateify. func (i *inodeFileState) afterLoad() { // Initialize the descriptor value. - if err := i.descriptor.initAfterLoad(i.mops, i.sattr.InodeID, &i.queue); err != nil { + if err := i.descriptor.initAfterLoad(i.sattr.InodeID, &i.queue); err != nil { panic(fmt.Sprintf("failed to load value of descriptor: %v", err)) } @@ -61,19 +46,4 @@ func (i *inodeFileState) afterLoad() { // change across save and restore, error out. panic(fs.ErrCorruption{fmt.Errorf("host %s conflict in host device mappings: %s", key, hostFileDevice)}) } - - if !i.descriptor.donated && i.sattr.Type == fs.RegularFile { - env, ok := fs.CurrentRestoreEnvironment() - if !ok { - panic("missing restore environment") - } - uattr := unstableAttr(i.mops, &s) - if env.ValidateFileSize && uattr.Size != i.savedUAttr.Size { - panic(fs.ErrCorruption{fmt.Errorf("file size has changed for %s: previously %d, now %d", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.Size, uattr.Size)}) - } - if env.ValidateFileTimestamp && uattr.ModificationTime != i.savedUAttr.ModificationTime { - panic(fs.ErrCorruption{fmt.Errorf("file modification time has changed for %s: previously %v, now %v", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.ModificationTime, uattr.ModificationTime)}) - } - i.savedUAttr = nil - } } diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go index 7221bc825..4c374681c 100644 --- a/pkg/sentry/fs/host/inode_test.go +++ b/pkg/sentry/fs/host/inode_test.go @@ -15,9 +15,6 @@ package host import ( - "io/ioutil" - "os" - "path" "syscall" "testing" @@ -25,69 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" ) -// TestMultipleReaddir verifies that multiple Readdir calls return the same -// thing if they use different dir contexts. -func TestMultipleReaddir(t *testing.T) { - p, err := ioutil.TempDir("", "readdir") - if err != nil { - t.Fatalf("Failed to create test dir: %v", err) - } - defer os.RemoveAll(p) - - f, err := os.Create(path.Join(p, "a.txt")) - if err != nil { - t.Fatalf("Failed to create a.txt: %v", err) - } - f.Close() - - f, err = os.Create(path.Join(p, "b.txt")) - if err != nil { - t.Fatalf("Failed to create b.txt: %v", err) - } - f.Close() - - fd, err := open(nil, p) - if err != nil { - t.Fatalf("Failed to open %q: %v", p, err) - } - ctx := contexttest.Context(t) - n, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false) - if err != nil { - t.Fatalf("Failed to create inode: %v", err) - } - - dirent := fs.NewDirent(ctx, n, "readdir") - openFile, err := n.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("Failed to get file: %v", err) - } - defer openFile.DecRef() - - c1 := &fs.DirCtx{DirCursor: new(string)} - if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c1, 0); err != nil { - t.Fatalf("First Readdir failed: %v", err) - } - - c2 := &fs.DirCtx{DirCursor: new(string)} - if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c2, 0); err != nil { - t.Errorf("Second Readdir failed: %v", err) - } - - if _, ok := c1.DentAttrs()["a.txt"]; !ok { - t.Errorf("want a.txt in first Readdir, got %v", c1.DentAttrs()) - } - if _, ok := c1.DentAttrs()["b.txt"]; !ok { - t.Errorf("want b.txt in first Readdir, got %v", c1.DentAttrs()) - } - - if _, ok := c2.DentAttrs()["a.txt"]; !ok { - t.Errorf("want a.txt in second Readdir, got %v", c2.DentAttrs()) - } - if _, ok := c2.DentAttrs()["b.txt"]; !ok { - t.Errorf("want b.txt in second Readdir, got %v", c2.DentAttrs()) - } -} - // TestCloseFD verifies fds will be closed. func TestCloseFD(t *testing.T) { var p [2]int diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go index e37e687c6..388108fdf 100644 --- a/pkg/sentry/fs/host/util.go +++ b/pkg/sentry/fs/host/util.go @@ -16,7 +16,6 @@ package host import ( "os" - "path" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -24,49 +23,10 @@ import ( "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/syserror" ) -func open(parent *inodeOperations, name string) (int, error) { - if parent == nil && !path.IsAbs(name) { - return -1, syserror.EINVAL - } - name = path.Clean(name) - - // Don't follow through symlinks. - flags := syscall.O_NOFOLLOW - - if fd, err := openAt(parent, name, flags|syscall.O_RDWR, 0); err == nil { - return fd, nil - } - // Retry as read-only. - if fd, err := openAt(parent, name, flags|syscall.O_RDONLY, 0); err == nil { - return fd, nil - } - - // Retry as write-only. - if fd, err := openAt(parent, name, flags|syscall.O_WRONLY, 0); err == nil { - return fd, nil - } - - // Retry as a symlink, by including O_PATH as an option. - fd, err := openAt(parent, name, linux.O_PATH|flags, 0) - if err == nil { - return fd, nil - } - - // Everything failed. - return -1, err -} - -func openAt(parent *inodeOperations, name string, flags int, perm linux.FileMode) (int, error) { - if parent == nil { - return syscall.Open(name, flags, uint32(perm)) - } - return syscall.Openat(parent.fileState.FD(), name, flags, uint32(perm)) -} - func nodeType(s *syscall.Stat_t) fs.InodeType { switch x := (s.Mode & syscall.S_IFMT); x { case syscall.S_IFLNK: @@ -107,54 +67,22 @@ func stableAttr(s *syscall.Stat_t) fs.StableAttr { } } -func owner(mo *superOperations, s *syscall.Stat_t) fs.FileOwner { - // User requested no translation, just return actual owner. - if mo.dontTranslateOwnership { - return fs.FileOwner{auth.KUID(s.Uid), auth.KGID(s.Gid)} +func owner(s *syscall.Stat_t) fs.FileOwner { + return fs.FileOwner{ + UID: auth.KUID(s.Uid), + GID: auth.KGID(s.Gid), } - - // Show only IDs relevant to the sandboxed task. I.e. if we not own the - // file, no sandboxed task can own the file. In that case, we - // use OverflowID for UID, implying that the IDs are not mapped in the - // "root" user namespace. - // - // E.g. - // sandbox's host EUID/EGID is 1/1. - // some_dir's host UID/GID is 2/1. - // Task that mounted this fs has virtualized EUID/EGID 5/5. - // - // If you executed `ls -n` in the sandboxed task, it would show: - // drwxwrxwrx [...] 65534 5 [...] some_dir - - // Files are owned by OverflowID by default. - owner := fs.FileOwner{auth.KUID(auth.OverflowUID), auth.KGID(auth.OverflowGID)} - - // If we own file on host, let mounting task's initial EUID own - // the file. - if s.Uid == hostUID { - owner.UID = mo.mounter.UID - } - - // If our group matches file's group, make file's group match - // the mounting task's initial EGID. - for _, gid := range hostGIDs { - if s.Gid == gid { - owner.GID = mo.mounter.GID - break - } - } - return owner } -func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr { +func unstableAttr(s *syscall.Stat_t) fs.UnstableAttr { return fs.UnstableAttr{ Size: s.Size, Usage: s.Blocks * 512, Perms: fs.FilePermsFromMode(linux.FileMode(s.Mode)), - Owner: owner(mo, s), - AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec), - ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec), - StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec), + Owner: owner(s), + AccessTime: time.FromUnix(s.Atim.Sec, s.Atim.Nsec), + ModificationTime: time.FromUnix(s.Mtim.Sec, s.Mtim.Nsec), + StatusChangeTime: time.FromUnix(s.Ctim.Sec, s.Ctim.Nsec), Links: uint64(s.Nlink), } } @@ -165,6 +93,8 @@ type dirInfo struct { bufp int // location of next record in buf. } +// LINT.IfChange + // isBlockError unwraps os errors and checks if they are caused by EAGAIN or // EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { @@ -177,6 +107,8 @@ func isBlockError(err error) bool { return false } +// LINT.ThenChange(../../fsimpl/host/util.go) + func hostEffectiveKIDs() (uint32, []uint32, error) { gids, err := os.Getgroups() if err != nil { diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go index 3ab36b088..23bd35d64 100644 --- a/pkg/sentry/fs/host/util_unsafe.go +++ b/pkg/sentry/fs/host/util_unsafe.go @@ -26,26 +26,6 @@ import ( // NulByte is a single NUL byte. It is passed to readlinkat as an empty string. var NulByte byte = '\x00' -func createLink(fd int, name string, linkName string) error { - namePtr, err := syscall.BytePtrFromString(name) - if err != nil { - return err - } - linkNamePtr, err := syscall.BytePtrFromString(linkName) - if err != nil { - return err - } - _, _, errno := syscall.Syscall( - syscall.SYS_SYMLINKAT, - uintptr(unsafe.Pointer(namePtr)), - uintptr(fd), - uintptr(unsafe.Pointer(linkNamePtr))) - if errno != 0 { - return errno - } - return nil -} - func readLink(fd int) (string, error) { // Buffer sizing copied from os.Readlink. for l := 128; ; l *= 2 { @@ -66,27 +46,6 @@ func readLink(fd int) (string, error) { } } -func unlinkAt(fd int, name string, dir bool) error { - namePtr, err := syscall.BytePtrFromString(name) - if err != nil { - return err - } - var flags uintptr - if dir { - flags = linux.AT_REMOVEDIR - } - _, _, errno := syscall.Syscall( - syscall.SYS_UNLINKAT, - uintptr(fd), - uintptr(unsafe.Pointer(namePtr)), - flags, - ) - if errno != 0 { - return errno - } - return nil -} - func timespecFromTimestamp(t ktime.Time, omit, setSysTime bool) syscall.Timespec { if omit { return syscall.Timespec{0, linux.UTIME_OMIT} diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index 928c90aa0..e3a715c1f 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -143,7 +143,10 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } var writeLen int64 - for event := i.events.Front(); event != nil; event = event.Next() { + for it := i.events.Front(); it != nil; { + event := it + it = it.Next() + // Does the buffer have enough remaining space to hold the event we're // about to write out? if dst.NumBytes() < int64(event.sizeOf()) { diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 95d5817ff..bd18177d4 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -40,47 +40,48 @@ import ( // LINT.IfChange -// newNet creates a new proc net entry. -func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode { +// newNetDir creates a new proc net entry. +func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + k := t.Kernel() + var contents map[string]*fs.Inode - // TODO(gvisor.dev/issue/1833): Support for using the network stack in the - // network namespace of the calling process. We should make this per-process, - // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. - if s := p.k.RootNetworkNamespace().Stack(); s != nil { + if s := t.NetworkNamespace().Stack(); s != nil { + // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task + // network namespace. contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), - "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(t, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(t, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a // header the stub is just the header otherwise it is // an empty file. - "arp": newStaticProcInode(ctx, msrc, []byte("IP address HW type Flags HW address Mask Device\n")), + "arp": newStaticProcInode(t, msrc, []byte("IP address HW type Flags HW address Mask Device\n")), - "netlink": newStaticProcInode(ctx, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")), - "netstat": newStaticProcInode(ctx, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")), - "packet": newStaticProcInode(ctx, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")), - "protocols": newStaticProcInode(ctx, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")), + "netlink": newStaticProcInode(t, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")), + "netstat": newStaticProcInode(t, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")), + "packet": newStaticProcInode(t, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")), + "protocols": newStaticProcInode(t, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")), // Linux sets psched values to: nsec per usec, psched // tick in ns, 1000000, high res timer ticks per sec // (ClockGetres returns 1ns resolution). - "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), - "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function\n")), - "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), - "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), - "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), - "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), + "psched": newStaticProcInode(t, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), + "ptype": newStaticProcInode(t, msrc, []byte("Type Device Function\n")), + "route": seqfile.NewSeqFileInode(t, &netRoute{s: s}, msrc), + "tcp": seqfile.NewSeqFileInode(t, &netTCP{k: k}, msrc), + "udp": seqfile.NewSeqFileInode(t, &netUDP{k: k}, msrc), + "unix": seqfile.NewSeqFileInode(t, &netUnix{k: k}, msrc), } if s.SupportsIPv6() { - contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) - contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) - contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) - contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")) + contents["if_inet6"] = seqfile.NewSeqFileInode(t, &ifinet6{s: s}, msrc) + contents["ipv6_route"] = newStaticProcInode(t, msrc, []byte("")) + contents["tcp6"] = seqfile.NewSeqFileInode(t, &netTCP6{k: k}, msrc) + contents["udp6"] = newStaticProcInode(t, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")) } } - d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) - return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) + d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) + return newProcInode(t, d, msrc, fs.SpecialDirectory, t) } // ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6. @@ -837,4 +838,4 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } -// LINT.ThenChange(../../fsimpl/proc/tasks_net.go) +// LINT.ThenChange(../../fsimpl/proc/task_net.go) diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index c8abb5052..c659224a7 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -70,6 +70,7 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string "loadavg": seqfile.NewSeqFileInode(ctx, &loadavgData{}, msrc), "meminfo": seqfile.NewSeqFileInode(ctx, &meminfoData{k}, msrc), "mounts": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/mounts"), msrc, fs.Symlink, nil), + "net": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/net"), msrc, fs.Symlink, nil), "self": newSelf(ctx, pidns, msrc), "stat": seqfile.NewSeqFileInode(ctx, &statData{k}, msrc), "thread-self": newThreadSelf(ctx, pidns, msrc), @@ -86,7 +87,6 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string } // Add more contents that need proc to be initialized. - p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc)) p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc)) return newProcInode(ctx, p, msrc, fs.SpecialDirectory, nil), nil diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 8ab8d8a02..d6c5dd2c1 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -72,24 +72,27 @@ var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - "io": newIO(t, msrc, isThreadGroup), - "maps": newMaps(t, msrc), - "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), - "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), - "ns": newNamespaceDir(t, msrc), - "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), - "statm": newStatm(t, msrc), - "status": newStatus(t, msrc, p.pidns), - "uid_map": newUIDMap(t, msrc), + "auxv": newAuxvec(t, msrc), + "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), + "comm": newComm(t, msrc), + "environ": newExecArgInode(t, msrc, environExecArg), + "exe": newExe(t, msrc), + "fd": newFdDir(t, msrc), + "fdinfo": newFdInfoDir(t, msrc), + "gid_map": newGIDMap(t, msrc), + "io": newIO(t, msrc, isThreadGroup), + "maps": newMaps(t, msrc), + "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), + "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), + "net": newNetDir(t, msrc), + "ns": newNamespaceDir(t, msrc), + "oom_score": newOOMScore(t, msrc), + "oom_score_adj": newOOMScoreAdj(t, msrc), + "smaps": newSmaps(t, msrc), + "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), + "statm": newStatm(t, msrc), + "status": newStatus(t, msrc, p.pidns), + "uid_map": newUIDMap(t, msrc), } if isThreadGroup { contents["task"] = p.newSubtasks(t, msrc) @@ -796,4 +799,95 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc return int64(n), err } +// newOOMScore returns a oom_score file. It is a stub that always returns 0. +// TODO(gvisor.dev/issue/1967) +func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newStaticProcInode(t, msrc, []byte("0\n")) +} + +// oomScoreAdj is a file containing the oom_score adjustment for a task. +// +// +stateify savable +type oomScoreAdj struct { + fsutil.SimpleFileInode + + t *kernel.Task +} + +// +stateify savable +type oomScoreAdjFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + t *kernel.Task +} + +// newOOMScoreAdj returns a oom_score_adj file. +func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + i := &oomScoreAdj{ + SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + t: t, + } + return newProcInode(t, i, msrc, fs.SpecialFile, t) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*oomScoreAdj) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (o *oomScoreAdj) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, dirent, flags, &oomScoreAdjFile{t: o.t}), nil +} + +// Read implements fs.FileOperations.Read. +func (f *oomScoreAdjFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if f.t.ExitState() == kernel.TaskExitDead { + return 0, syserror.ESRCH + } + var buf bytes.Buffer + fmt.Fprintf(&buf, "%d\n", f.t.OOMScoreAdj()) + if offset >= int64(buf.Len()) { + return 0, io.EOF + } + n, err := dst.CopyOut(ctx, buf.Bytes()[offset:]) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is large. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + + if f.t.ExitState() == kernel.TaskExitDead { + return 0, syserror.ESRCH + } + if err := f.t.SetOOMScoreAdj(v); err != nil { + return 0, err + } + + return n, nil +} + // LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go) diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index 6aa17bfc1..79b808359 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -115,8 +115,6 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) // // remainingTraversals is not configurable in VFS2, all callers are using the // default anyways. -// -// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission. func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem() creds := auth.CredentialsFromContext(ctx) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index e05429d41..8497be615 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -255,6 +256,15 @@ func (fs *filesystem) statTo(stat *linux.Statfs) { // TODO(b/134676337): Set Statfs.Flags and Statfs.FSID. } +// AccessAt implements vfs.Filesystem.Impl.AccessAt. +func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { + _, inode, err := fs.walk(rp, false) + if err != nil { + return err + } + return inode.checkPermissions(rp.Credentials(), ats) +} + // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { vfsd, inode, err := fs.walk(rp, false) diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 5cfb0dc4c..38e4cdbc5 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) @@ -499,6 +500,18 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) { putDentrySlice(*ds) } +// AccessAt implements vfs.Filesystem.Impl.AccessAt. +func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckCaching(&ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + return err + } + return d.checkPermissions(creds, ats, d.isDir()) +} + // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { var ds *[]*dentry diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index c4a8f0b38..999485492 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -713,7 +713,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } if err := mnt.CheckBeginWrite(); err != nil { diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index e95209661..3593eb1d5 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -126,6 +126,11 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off if opts.Flags != 0 { return 0, syserror.EOPNOTSUPP } + limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) + if err != nil { + return 0, err + } + src = src.TakeFirst64(limit) d := fd.dentry() d.metadataMu.Lock() diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 08c691c47..274f7346f 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -107,6 +107,14 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, syserror.EOPNOTSUPP } + if fd.dentry().fileType() == linux.S_IFREG { + limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) + if err != nil { + return 0, err + } + src = src.TakeFirst64(limit) + } + // Do a buffered write. See rationale in PRead. if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { d.touchCMtime(ctx) diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD new file mode 100644 index 000000000..5d67f88e3 --- /dev/null +++ b/pkg/sentry/fsimpl/host/BUILD @@ -0,0 +1,29 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "host", + srcs = [ + "default_file.go", + "host.go", + "util.go", + ], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/fd", + "//pkg/log", + "//pkg/refs", + "//pkg/safemem", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/sentry/fsimpl/host/default_file.go b/pkg/sentry/fsimpl/host/default_file.go new file mode 100644 index 000000000..459238603 --- /dev/null +++ b/pkg/sentry/fsimpl/host/default_file.go @@ -0,0 +1,247 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "math" + "syscall" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// defaultFileFD implements FileDescriptionImpl for non-socket, non-TTY files. +type defaultFileFD struct { + fileDescription + + // canMap specifies whether we allow the file to be memory mapped. + canMap bool + + // mu protects the fields below. + mu sync.Mutex + + // offset specifies the current file offset. + offset int64 +} + +// TODO(gvisor.dev/issue/1672): Implement Waitable interface. + +// PRead implements FileDescriptionImpl. +func (f *defaultFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. + if f.inode.isStream { + return 0, syserror.ESPIPE + } + + return readFromHostFD(ctx, f.inode.hostFD, dst, offset, int(opts.Flags)) +} + +// Read implements FileDescriptionImpl. +func (f *defaultFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. + if f.inode.isStream { + // These files can't be memory mapped, assert this. + if f.canMap { + panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") + } + + n, err := readFromHostFD(ctx, f.inode.hostFD, dst, -1, int(opts.Flags)) + if isBlockError(err) { + // If we got any data at all, return it as a "completed" partial read + // rather than retrying until complete. + if n != 0 { + err = nil + } else { + err = syserror.ErrWouldBlock + } + } + return n, err + } + // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. + f.mu.Lock() + n, err := readFromHostFD(ctx, f.inode.hostFD, dst, f.offset, int(opts.Flags)) + f.offset += n + f.mu.Unlock() + return n, err +} + +func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags int) (int64, error) { + // TODO(gvisor.dev/issue/1672): Support select preadv2 flags. + if flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + var reader safemem.Reader + if offset == -1 { + reader = safemem.FromIOReader{fd.NewReadWriter(hostFD)} + } else { + reader = safemem.FromVecReaderFunc{ + func(srcs [][]byte) (int64, error) { + n, err := unix.Preadv(hostFD, srcs, offset) + return int64(n), err + }, + } + } + n, err := dst.CopyOutFrom(ctx, reader) + return int64(n), err +} + +// PWrite implements FileDescriptionImpl. +func (f *defaultFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. + if f.inode.isStream { + return 0, syserror.ESPIPE + } + return writeToHostFD(ctx, f.inode.hostFD, src, offset, int(opts.Flags)) +} + +// Write implements FileDescriptionImpl. +func (f *defaultFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // TODO(b/34716638): Some char devices do support offsets, e.g. /dev/null. + if f.inode.isStream { + // These files can't be memory mapped, assert this. + if f.canMap { + panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") + } + + n, err := writeToHostFD(ctx, f.inode.hostFD, src, -1, int(opts.Flags)) + if isBlockError(err) { + err = syserror.ErrWouldBlock + } + return n, err + } + // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. + // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file. + f.mu.Lock() + n, err := writeToHostFD(ctx, f.inode.hostFD, src, f.offset, int(opts.Flags)) + f.offset += n + f.mu.Unlock() + return n, err +} + +func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offset int64, flags int) (int64, error) { + // TODO(gvisor.dev/issue/1672): Support select pwritev2 flags. + if flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) + if err != nil { + return 0, err + } + src = src.TakeFirst64(limit) + + var writer safemem.Writer + if offset == -1 { + writer = safemem.FromIOWriter{fd.NewReadWriter(hostFD)} + } else { + writer = safemem.FromVecWriterFunc{ + func(srcs [][]byte) (int64, error) { + n, err := unix.Pwritev(hostFD, srcs, offset) + return int64(n), err + }, + } + } + n, err := src.CopyInTo(ctx, writer) + return int64(n), err +} + +// Seek implements FileDescriptionImpl. +// +// Note that we do not support seeking on directories, since we do not even +// allow directory fds to be imported at all. +func (f *defaultFileFD) Seek(_ context.Context, offset int64, whence int32) (int64, error) { + // TODO(b/34716638): Some char devices do support seeking, e.g. /dev/null. + if f.inode.isStream { + return 0, syserror.ESPIPE + } + + f.mu.Lock() + defer f.mu.Unlock() + + switch whence { + case linux.SEEK_SET: + if offset < 0 { + return f.offset, syserror.EINVAL + } + f.offset = offset + + case linux.SEEK_CUR: + // Check for overflow. Note that underflow cannot occur, since f.offset >= 0. + if offset > math.MaxInt64-f.offset { + return f.offset, syserror.EOVERFLOW + } + if f.offset+offset < 0 { + return f.offset, syserror.EINVAL + } + f.offset += offset + + case linux.SEEK_END: + var s syscall.Stat_t + if err := syscall.Fstat(f.inode.hostFD, &s); err != nil { + return f.offset, err + } + size := s.Size + + // Check for overflow. Note that underflow cannot occur, since size >= 0. + if offset > math.MaxInt64-size { + return f.offset, syserror.EOVERFLOW + } + if size+offset < 0 { + return f.offset, syserror.EINVAL + } + f.offset = size + offset + + case linux.SEEK_DATA, linux.SEEK_HOLE: + // Modifying the offset in the host file table should not matter, since + // this is the only place where we use it. + // + // For reading and writing, we always rely on our internal offset. + n, err := unix.Seek(f.inode.hostFD, offset, int(whence)) + if err != nil { + return f.offset, err + } + f.offset = n + + default: + // Invalid whence. + return f.offset, syserror.EINVAL + } + + return f.offset, nil +} + +// Sync implements FileDescriptionImpl. +func (f *defaultFileFD) Sync(context.Context) error { + // TODO(gvisor.dev/issue/1672): Currently we do not support the SyncData optimization, so we always sync everything. + return unix.Fsync(f.inode.hostFD) +} + +// ConfigureMMap implements FileDescriptionImpl. +func (f *defaultFileFD) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { + if !f.canMap { + return syserror.ENODEV + } + // TODO(gvisor.dev/issue/1672): Implement ConfigureMMap and Mappable interface. + return syserror.ENODEV +} diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go new file mode 100644 index 000000000..2eebcd60c --- /dev/null +++ b/pkg/sentry/fsimpl/host/host.go @@ -0,0 +1,396 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package host provides a filesystem implementation for host files imported as +// file descriptors. +package host + +import ( + "errors" + "fmt" + "syscall" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" +) + +// filesystem implements vfs.FilesystemImpl. +type filesystem struct { + kernfs.Filesystem +} + +// NewMount returns a new disconnected mount in vfsObj that may be passed to ImportFD. +func NewMount(vfsObj *vfs.VirtualFilesystem) (*vfs.Mount, error) { + fs := &filesystem{} + fs.Init(vfsObj) + vfsfs := fs.VFSFilesystem() + // NewDisconnectedMount will take an additional reference on vfsfs. + defer vfsfs.DecRef() + return vfsObj.NewDisconnectedMount(vfsfs, nil, &vfs.MountOptions{}) +} + +// ImportFD sets up and returns a vfs.FileDescription from a donated fd. +func ImportFD(mnt *vfs.Mount, hostFD int, ownerUID auth.KUID, ownerGID auth.KGID, isTTY bool) (*vfs.FileDescription, error) { + fs, ok := mnt.Filesystem().Impl().(*kernfs.Filesystem) + if !ok { + return nil, fmt.Errorf("can't import host FDs into filesystems of type %T", mnt.Filesystem().Impl()) + } + + // Retrieve metadata. + var s syscall.Stat_t + if err := syscall.Fstat(hostFD, &s); err != nil { + return nil, err + } + + fileMode := linux.FileMode(s.Mode) + fileType := fileMode.FileType() + // Pipes, character devices, and sockets. + isStream := fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK + + i := &inode{ + hostFD: hostFD, + isStream: isStream, + isTTY: isTTY, + ino: fs.NextIno(), + mode: fileMode, + uid: ownerUID, + gid: ownerGID, + } + + d := &kernfs.Dentry{} + d.Init(i) + // i.open will take a reference on d. + defer d.DecRef() + + return i.open(d.VFSDentry(), mnt) +} + +// inode implements kernfs.Inode. +type inode struct { + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + // When the reference count reaches zero, the host fd is closed. + refs.AtomicRefCount + + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + hostFD int + + // isStream is true if the host fd points to a file representing a stream, + // e.g. a socket or a pipe. Such files are not seekable and can return + // EWOULDBLOCK for I/O operations. + // + // This field is initialized at creation time and is immutable. + isStream bool + + // isTTY is true if this file represents a TTY. + // + // This field is initialized at creation time and is immutable. + isTTY bool + + // ino is an inode number unique within this filesystem. + ino uint64 + + // mu protects the inode metadata below. + // TODO(gvisor.dev/issue/1672): actually protect fields below. + //mu sync.Mutex + + // mode is the file mode of this inode. Note that this value may become out + // of date if the mode is changed on the host, e.g. with chmod. + mode linux.FileMode + + // uid and gid of the file owner. Note that these refer to the owner of the + // file created on import, not the fd on the host. + uid auth.KUID + gid auth.KGID +} + +// Note that these flags may become out of date, since they can be modified +// on the host, e.g. with fcntl. +func fileFlagsFromHostFD(fd int) (int, error) { + flags, err := unix.FcntlInt(uintptr(fd), syscall.F_GETFL, 0) + if err != nil { + log.Warningf("Failed to get file flags for donated FD %d: %v", fd, err) + return 0, err + } + // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags. + flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND + return flags, nil +} + +// CheckPermissions implements kernfs.Inode. +func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error { + return vfs.GenericCheckPermissions(creds, atx, false /* isDir */, uint16(i.mode), i.uid, i.gid) +} + +// Mode implements kernfs.Inode. +func (i *inode) Mode() linux.FileMode { + return i.mode +} + +// Stat implements kernfs.Inode. +func (i *inode) Stat(_ *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + if opts.Mask&linux.STATX__RESERVED != 0 { + return linux.Statx{}, syserror.EINVAL + } + if opts.Sync&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE { + return linux.Statx{}, syserror.EINVAL + } + + // Limit our host call only to known flags. + mask := opts.Mask & linux.STATX_ALL + var s unix.Statx_t + err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s) + // Fallback to fstat(2), if statx(2) is not supported on the host. + // + // TODO(b/151263641): Remove fallback. + if err == syserror.ENOSYS { + return i.fstat(opts) + } else if err != nil { + return linux.Statx{}, err + } + + ls := linux.Statx{Mask: mask} + // Unconditionally fill blksize, attributes, and device numbers, as indicated + // by /include/uapi/linux/stat.h. + // + // RdevMajor/RdevMinor are left as zero, so as not to expose host device + // numbers. + // + // TODO(gvisor.dev/issue/1672): Use kernfs-specific, internally defined + // device numbers. If we use the device number from the host, it may collide + // with another sentry-internal device number. We handle device/inode + // numbers without relying on the host to prevent collisions. + ls.Blksize = s.Blksize + ls.Attributes = s.Attributes + ls.AttributesMask = s.Attributes_mask + + if mask|linux.STATX_TYPE != 0 { + ls.Mode |= s.Mode & linux.S_IFMT + } + if mask|linux.STATX_MODE != 0 { + ls.Mode |= s.Mode &^ linux.S_IFMT + } + if mask|linux.STATX_NLINK != 0 { + ls.Nlink = s.Nlink + } + if mask|linux.STATX_ATIME != 0 { + ls.Atime = unixToLinuxStatxTimestamp(s.Atime) + } + if mask|linux.STATX_BTIME != 0 { + ls.Btime = unixToLinuxStatxTimestamp(s.Btime) + } + if mask|linux.STATX_CTIME != 0 { + ls.Ctime = unixToLinuxStatxTimestamp(s.Ctime) + } + if mask|linux.STATX_MTIME != 0 { + ls.Mtime = unixToLinuxStatxTimestamp(s.Mtime) + } + if mask|linux.STATX_SIZE != 0 { + ls.Size = s.Size + } + if mask|linux.STATX_BLOCKS != 0 { + ls.Blocks = s.Blocks + } + + // Use our own internal inode number and file owner. + if mask|linux.STATX_INO != 0 { + ls.Ino = i.ino + } + if mask|linux.STATX_UID != 0 { + ls.UID = uint32(i.uid) + } + if mask|linux.STATX_GID != 0 { + ls.GID = uint32(i.gid) + } + + return ls, nil +} + +// fstat is a best-effort fallback for inode.Stat() if the host does not +// support statx(2). +// +// We ignore the mask and sync flags in opts and simply supply +// STATX_BASIC_STATS, as fstat(2) itself does not allow the specification +// of a mask or sync flags. fstat(2) does not provide any metadata +// equivalent to Statx.Attributes, Statx.AttributesMask, or Statx.Btime, so +// those fields remain empty. +func (i *inode) fstat(opts vfs.StatOptions) (linux.Statx, error) { + var s unix.Stat_t + if err := unix.Fstat(i.hostFD, &s); err != nil { + return linux.Statx{}, err + } + + // Note that rdev numbers are left as 0; do not expose host device numbers. + ls := linux.Statx{ + Mask: linux.STATX_BASIC_STATS, + Blksize: uint32(s.Blksize), + Nlink: uint32(s.Nlink), + Mode: uint16(s.Mode), + Size: uint64(s.Size), + Blocks: uint64(s.Blocks), + Atime: timespecToStatxTimestamp(s.Atim), + Ctime: timespecToStatxTimestamp(s.Ctim), + Mtime: timespecToStatxTimestamp(s.Mtim), + } + + // Use our own internal inode number and file owner. + // + // TODO(gvisor.dev/issue/1672): Use a kernfs-specific device number as well. + // If we use the device number from the host, it may collide with another + // sentry-internal device number. We handle device/inode numbers without + // relying on the host to prevent collisions. + ls.Ino = i.ino + ls.UID = uint32(i.uid) + ls.GID = uint32(i.gid) + + return ls, nil +} + +// SetStat implements kernfs.Inode. +func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + s := opts.Stat + + m := s.Mask + if m == 0 { + return nil + } + if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 { + return syserror.EPERM + } + if err := vfs.CheckSetStat(ctx, creds, &s, uint16(i.Mode().Permissions()), i.uid, i.gid); err != nil { + return err + } + + if m&linux.STATX_MODE != 0 { + if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil { + return err + } + i.mode = linux.FileMode(s.Mode) + } + if m&linux.STATX_SIZE != 0 { + if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { + return err + } + } + if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 { + timestamps := []unix.Timespec{ + toTimespec(s.Atime, m&linux.STATX_ATIME == 0), + toTimespec(s.Mtime, m&linux.STATX_MTIME == 0), + } + if err := unix.UtimesNanoAt(i.hostFD, "", timestamps, unix.AT_EMPTY_PATH); err != nil { + return err + } + } + return nil +} + +// DecRef implements kernfs.Inode. +func (i *inode) DecRef() { + i.AtomicRefCount.DecRefWithDestructor(i.Destroy) +} + +// Destroy implements kernfs.Inode. +func (i *inode) Destroy() { + if err := unix.Close(i.hostFD); err != nil { + log.Warningf("failed to close host fd %d: %v", i.hostFD, err) + } +} + +// Open implements kernfs.Inode. +func (i *inode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + return i.open(vfsd, rp.Mount()) +} + +func (i *inode) open(d *vfs.Dentry, mnt *vfs.Mount) (*vfs.FileDescription, error) { + fileType := i.mode.FileType() + if fileType == syscall.S_IFSOCK { + if i.isTTY { + return nil, errors.New("cannot use host socket as TTY") + } + // TODO(gvisor.dev/issue/1672): support importing sockets. + return nil, errors.New("importing host sockets not supported") + } + + // TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that + // we don't allow importing arbitrary file types without proper support. + if i.isTTY { + // TODO(gvisor.dev/issue/1672): support importing host fd as TTY. + return nil, errors.New("importing host fd as TTY not supported") + } + + // For simplicity, set offset to 0. Technically, we should + // only set to 0 on files that are not seekable (sockets, pipes, etc.), + // and use the offset from the host fd otherwise. + fd := &defaultFileFD{ + fileDescription: fileDescription{ + inode: i, + }, + canMap: canMap(uint32(fileType)), + mu: sync.Mutex{}, + offset: 0, + } + + vfsfd := &fd.vfsfd + flags, err := fileFlagsFromHostFD(i.hostFD) + if err != nil { + return nil, err + } + + if err := vfsfd.Init(fd, uint32(flags), mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return vfsfd, nil +} + +// fileDescription is embedded by host fd implementations of FileDescriptionImpl. +type fileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + + // inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but + // cached to reduce indirections and casting. fileDescription does not hold + // a reference on the inode through the inode field (since one is already + // held via the Dentry). + // + // inode is immutable after fileDescription creation. + inode *inode +} + +// SetStat implements vfs.FileDescriptionImpl. +func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + creds := auth.CredentialsFromContext(ctx) + return f.inode.SetStat(ctx, nil, creds, opts) +} + +// Stat implements vfs.FileDescriptionImpl. +func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) { + return f.inode.Stat(nil, opts) +} + +// Release implements vfs.FileDescriptionImpl. +func (f *fileDescription) Release() { + // noop +} diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go new file mode 100644 index 000000000..d519feef5 --- /dev/null +++ b/pkg/sentry/fsimpl/host/util.go @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "syscall" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/syserror" +) + +func toTimespec(ts linux.StatxTimestamp, omit bool) unix.Timespec { + if omit { + return unix.Timespec{ + Sec: 0, + Nsec: unix.UTIME_OMIT, + } + } + return unix.Timespec{ + Sec: int64(ts.Sec), + Nsec: int64(ts.Nsec), + } +} + +func unixToLinuxStatxTimestamp(ts unix.StatxTimestamp) linux.StatxTimestamp { + return linux.StatxTimestamp{Sec: ts.Sec, Nsec: ts.Nsec} +} + +func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { + return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)} +} + +// wouldBlock returns true for file types that can return EWOULDBLOCK +// for blocking operations, e.g. pipes, character devices, and sockets. +func wouldBlock(fileType uint32) bool { + return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK +} + +// canMap returns true if a file with fileType is allowed to be memory mapped. +// This is ported over from VFS1, but it's probably not the best way for us +// to check if a file can be memory mapped. +func canMap(fileType uint32) bool { + // TODO(gvisor.dev/issue/1672): Also allow "special files" to be mapped (see fs/host:canMap()). + // + // TODO(b/38213152): Some obscure character devices can be mapped. + return fileType == syscall.S_IFREG +} + +// isBlockError checks if an error is EAGAIN or EWOULDBLOCK. +// If so, they can be transformed into syserror.ErrWouldBlock. +func isBlockError(err error) bool { + return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK +} diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index d092ccb2a..d8bddbafa 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -61,9 +61,10 @@ func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vf return &fd.vfsfd, nil } -// SetStat implements Inode.SetStat. -func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error { - // DynamicBytesFiles are immutable. +// SetStat implements Inode.SetStat. By default DynamicBytesFile doesn't allow +// inode attributes to be changed. Override SetStat() making it call +// f.InodeAttrs to allow it. +func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } @@ -122,7 +123,7 @@ func (fd *DynamicBytesFD) Release() {} // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(fs), nil + return fd.inode.Stat(fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 5650512e0..75c4bab1a 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -17,6 +17,7 @@ package kernfs import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -107,9 +108,13 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent fs.mu.Lock() defer fs.mu.Unlock() + opts := vfs.StatOptions{Mask: linux.STATX_INO} // Handle ".". if fd.off == 0 { - stat := fd.inode().Stat(vfsFS) + stat, err := fd.inode().Stat(vfsFS, opts) + if err != nil { + return err + } dirent := vfs.Dirent{ Name: ".", Type: linux.DT_DIR, @@ -125,7 +130,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // Handle "..". if fd.off == 1 { parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode - stat := parentInode.Stat(vfsFS) + stat, err := parentInode.Stat(vfsFS, opts) + if err != nil { + return err + } dirent := vfs.Dirent{ Name: "..", Type: linux.FileMode(stat.Mode).DirentType(), @@ -146,7 +154,10 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent childIdx := fd.off - 2 for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { inode := it.Dentry.Impl().(*Dentry).inode - stat := inode.Stat(vfsFS) + stat, err := inode.Stat(vfsFS, opts) + if err != nil { + return err + } dirent := vfs.Dirent{ Name: it.Name, Type: linux.FileMode(stat.Mode).DirentType(), @@ -190,12 +201,13 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.filesystem() inode := fd.inode() - return inode.Stat(fs), nil + return inode.Stat(fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { fs := fd.filesystem() + creds := auth.CredentialsFromContext(ctx) inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(fs, opts) + return inode.SetStat(ctx, fs, creds, opts) } diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 292f58afd..31da8b511 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This file implements vfs.FilesystemImpl for kernfs. - package kernfs +// This file implements vfs.FilesystemImpl for kernfs. + import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) @@ -229,6 +230,19 @@ func (fs *Filesystem) Sync(ctx context.Context) error { return nil } +// AccessAt implements vfs.Filesystem.Impl.AccessAt. +func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + defer fs.processDeferredDecRefs() + + _, inode, err := fs.walkExistingLocked(ctx, rp) + if err != nil { + return err + } + return inode.CheckPermissions(ctx, creds, ats) +} + // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { fs.mu.RLock() @@ -622,7 +636,7 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts if opts.Stat.Mask == 0 { return nil } - return inode.SetStat(fs.VFSFilesystem(), opts) + return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts) } // StatAt implements vfs.FilesystemImpl.StatAt. @@ -634,7 +648,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } - return inode.Stat(fs.VFSFilesystem()), nil + return inode.Stat(fs.VFSFilesystem(), opts) } // StatFSAt implements vfs.FilesystemImpl.StatFSAt. diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 099d70a16..c612dcf07 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -36,20 +36,20 @@ type InodeNoopRefCount struct { } // IncRef implements Inode.IncRef. -func (n *InodeNoopRefCount) IncRef() { +func (InodeNoopRefCount) IncRef() { } // DecRef implements Inode.DecRef. -func (n *InodeNoopRefCount) DecRef() { +func (InodeNoopRefCount) DecRef() { } // TryIncRef implements Inode.TryIncRef. -func (n *InodeNoopRefCount) TryIncRef() bool { +func (InodeNoopRefCount) TryIncRef() bool { return true } // Destroy implements Inode.Destroy. -func (n *InodeNoopRefCount) Destroy() { +func (InodeNoopRefCount) Destroy() { } // InodeDirectoryNoNewChildren partially implements the Inode interface. @@ -58,27 +58,27 @@ func (n *InodeNoopRefCount) Destroy() { type InodeDirectoryNoNewChildren struct{} // NewFile implements Inode.NewFile. -func (*InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { return nil, syserror.EPERM } // NewDir implements Inode.NewDir. -func (*InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { return nil, syserror.EPERM } // NewLink implements Inode.NewLink. -func (*InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { return nil, syserror.EPERM } // NewSymlink implements Inode.NewSymlink. -func (*InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { return nil, syserror.EPERM } // NewNode implements Inode.NewNode. -func (*InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { return nil, syserror.EPERM } @@ -90,62 +90,62 @@ type InodeNotDirectory struct { } // HasChildren implements Inode.HasChildren. -func (*InodeNotDirectory) HasChildren() bool { +func (InodeNotDirectory) HasChildren() bool { return false } // NewFile implements Inode.NewFile. -func (*InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { panic("NewFile called on non-directory inode") } // NewDir implements Inode.NewDir. -func (*InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { panic("NewDir called on non-directory inode") } // NewLink implements Inode.NewLinkink. -func (*InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { panic("NewLink called on non-directory inode") } // NewSymlink implements Inode.NewSymlink. -func (*InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { panic("NewSymlink called on non-directory inode") } // NewNode implements Inode.NewNode. -func (*InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { panic("NewNode called on non-directory inode") } // Unlink implements Inode.Unlink. -func (*InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error { +func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error { panic("Unlink called on non-directory inode") } // RmDir implements Inode.RmDir. -func (*InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error { +func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error { panic("RmDir called on non-directory inode") } // Rename implements Inode.Rename. -func (*InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) { +func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) { panic("Rename called on non-directory inode") } // Lookup implements Inode.Lookup. -func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { panic("Lookup called on non-directory inode") } // IterDirents implements Inode.IterDirents. -func (*InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { panic("IterDirents called on non-directory inode") } // Valid implements Inode.Valid. -func (*InodeNotDirectory) Valid(context.Context) bool { +func (InodeNotDirectory) Valid(context.Context) bool { return true } @@ -157,17 +157,17 @@ func (*InodeNotDirectory) Valid(context.Context) bool { type InodeNoDynamicLookup struct{} // Lookup implements Inode.Lookup. -func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { return nil, syserror.ENOENT } // IterDirents implements Inode.IterDirents. -func (*InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { return offset, nil } // Valid implements Inode.Valid. -func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool { +func (InodeNoDynamicLookup) Valid(ctx context.Context) bool { return true } @@ -177,7 +177,7 @@ func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool { type InodeNotSymlink struct{} // Readlink implements Inode.Readlink. -func (*InodeNotSymlink) Readlink(context.Context) (string, error) { +func (InodeNotSymlink) Readlink(context.Context) (string, error) { return "", syserror.EINVAL } @@ -219,7 +219,7 @@ func (a *InodeAttrs) Mode() linux.FileMode { // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. -func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx { +func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK stat.Ino = atomic.LoadUint64(&a.ino) @@ -228,13 +228,23 @@ func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx { stat.GID = atomic.LoadUint32(&a.gid) stat.Nlink = atomic.LoadUint32(&a.nlink) - // TODO: Implement other stat fields like timestamps. + // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - return stat + return stat, nil } // SetStat implements Inode.SetStat. -func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { +func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + if opts.Stat.Mask == 0 { + return nil + } + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { + return syserror.EPERM + } + if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, uint16(a.Mode().Permissions()), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { + return err + } + stat := opts.Stat if stat.Mask&linux.STATX_MODE != 0 { for { @@ -256,7 +266,7 @@ func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { // Note that not all fields are modifiable. For example, the file type and // inode numbers are immutable after node creation. - // TODO: Implement other stat fields like timestamps. + // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. return nil } @@ -554,3 +564,16 @@ func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, &opts) return fd.VFSFileDescription(), nil } + +// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// AlwaysValid partially implements kernfs.inodeDynamicLookup. +type AlwaysValid struct{} + +// Valid implements kernfs.inodeDynamicLookup. +func (*AlwaysValid) Valid(context.Context) bool { + return true +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index c74fa999b..794e38908 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -176,8 +176,6 @@ type Dentry struct { vfsd vfs.Dentry inode Inode - refs uint64 - // flags caches useful information about the dentry from the inode. See the // dflags* consts above. Must be accessed by atomic ops. flags uint32 @@ -302,7 +300,8 @@ type Inode interface { // this inode. The returned file description should hold a reference on the // inode for its lifetime. // - // Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry. + // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing + // the inode on which Open() is being called. Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) } @@ -320,7 +319,7 @@ type inodeMetadata interface { // CheckPermissions checks that creds may access this inode for the // requested access type, per the the rules of // fs/namei.c:generic_permission(). - CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error + CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error // Mode returns the (struct stat)::st_mode value for this inode. This is // separated from Stat for performance. @@ -328,11 +327,13 @@ type inodeMetadata interface { // Stat returns the metadata for this inode. This corresponds to // vfs.FilesystemImpl.StatAt. - Stat(fs *vfs.Filesystem) linux.Statx + Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) // SetStat updates the metadata for this inode. This corresponds to - // vfs.FilesystemImpl.SetStatAt. - SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error + // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking + // if the operation can be performed (see vfs.CheckSetStat() for common + // checks). + SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error } // Precondition: All methods in this interface may only be called on directory diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 0459fb305..fb0d25ad7 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -91,7 +91,7 @@ type attrs struct { kernfs.InodeAttrs } -func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error { +func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 0ee7eb9b7..5918d3309 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -18,6 +18,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" ) // StaticSymlink provides an Inode implementation for symlinks that point to @@ -52,3 +54,8 @@ func (s *StaticSymlink) Init(creds *auth.Credentials, ino uint64, target string) func (s *StaticSymlink) Readlink(_ context.Context) (string, error) { return s.target, nil } + +// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +func (*StaticSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index a83245866..8156984eb 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -8,10 +8,11 @@ go_library( "filesystem.go", "subtasks.go", "task.go", + "task_fds.go", "task_files.go", + "task_net.go", "tasks.go", "tasks_files.go", - "tasks_net.go", "tasks_sys.go", ], visibility = ["//pkg/sentry:internal"], @@ -19,8 +20,10 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/log", + "//pkg/refs", "//pkg/safemem", "//pkg/sentry/fs", + "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -53,6 +56,7 @@ go_test( "//pkg/fspath", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", + "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index f3f4e49b4..a21313666 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) @@ -34,6 +35,7 @@ type subtasksInode struct { kernfs.InodeDirectoryNoNewChildren kernfs.InodeAttrs kernfs.OrderedChildren + kernfs.AlwaysValid task *kernel.Task pidns *kernel.PIDNamespace @@ -61,11 +63,6 @@ func newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, inoGen InoGenera return dentry } -// Valid implements kernfs.inodeDynamicLookup. -func (i *subtasksInode) Valid(ctx context.Context) bool { - return true -} - // Lookup implements kernfs.inodeDynamicLookup. func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { tid, err := strconv.ParseUint(name, 10, 32) @@ -121,8 +118,18 @@ func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.O } // Stat implements kernfs.Inode. -func (i *subtasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx { - stat := i.InodeAttrs.Stat(vsfs) - stat.Nlink += uint32(i.task.ThreadGroup().Count()) - return stat +func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(vsfs, opts) + if err != nil { + return linux.Statx{}, err + } + if opts.Mask&linux.STATX_NLINK != 0 { + stat.Nlink += uint32(i.task.ThreadGroup().Count()) + } + return stat, nil +} + +// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 2d814668a..49d6efb0e 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -45,28 +45,31 @@ var _ kernfs.Inode = (*taskInode)(nil) func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry { contents := map[string]*kernfs.Dentry{ - "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}), - "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), - "comm": newComm(task, inoGen.NextIno(), 0444), - "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), - //"exe": newExe(t, msrc), - //"fd": newFdDir(t, msrc), - //"fdinfo": newFdInfoDir(t, msrc), - "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}), - "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)), - "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}), - //"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), - //"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), + "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}), + "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), + "comm": newComm(task, inoGen.NextIno(), 0444), + "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), + "exe": newExeSymlink(task, inoGen.NextIno()), + "fd": newFDDirInode(task, inoGen), + "fdinfo": newFDInfoDirInode(task, inoGen), + "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}), + "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)), + "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}), + "mountinfo": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountInfoData{task: task}), + "mounts": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mountsData{task: task}), + "net": newTaskNetDir(task, inoGen), "ns": newTaskOwnedDir(task, inoGen.NextIno(), 0511, map[string]*kernfs.Dentry{ "net": newNamespaceSymlink(task, inoGen.NextIno(), "net"), "pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"), "user": newNamespaceSymlink(task, inoGen.NextIno(), "user"), }), - "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}), - "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), - "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}), - "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}), - "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}), + "oom_score": newTaskOwnedFile(task, inoGen.NextIno(), 0444, newStaticFile("0\n")), + "oom_score_adj": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &oomScoreAdj{task: task}), + "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}), + "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), + "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}), + "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}), + "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { contents["task"] = newSubtasks(task, pidns, inoGen, cgroupControllers) @@ -104,13 +107,9 @@ func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenO return fd.VFSFileDescription(), nil } -// SetStat implements kernfs.Inode. -func (i *taskInode) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { - stat := opts.Stat - if stat.Mask&linux.STATX_MODE != 0 { - return syserror.EPERM - } - return nil +// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM } // taskOwnedInode implements kernfs.Inode and overrides inode owner with task @@ -152,12 +151,21 @@ func newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, childre } // Stat implements kernfs.Inode. -func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx { - stat := i.Inode.Stat(fs) - uid, gid := i.getOwner(linux.FileMode(stat.Mode)) - stat.UID = uint32(uid) - stat.GID = uint32(gid) - return stat +func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.Inode.Stat(fs, opts) + if err != nil { + return linux.Statx{}, err + } + if opts.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { + uid, gid := i.getOwner(linux.FileMode(stat.Mode)) + if opts.Mask&linux.STATX_UID != 0 { + stat.UID = uint32(uid) + } + if opts.Mask&linux.STATX_GID != 0 { + stat.GID = uint32(gid) + } + } + return stat, nil } // CheckPermissions implements kernfs.Inode. @@ -234,7 +242,7 @@ func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentr // member, there is one entry containing three colon-separated fields: // hierarchy-ID:controller-list:cgroup-path" func newCgroupData(controllers map[string]string) dynamicInode { - buf := bytes.Buffer{} + var buf bytes.Buffer // The hierarchy ids must be positive integers (for cgroup v1), but the // exact number does not matter, so long as they are unique. We can diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go new file mode 100644 index 000000000..76bfc5307 --- /dev/null +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -0,0 +1,287 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "bytes" + "fmt" + "sort" + "strconv" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +type fdDir struct { + inoGen InoGenerator + task *kernel.Task + + // When produceSymlinks is set, dirents produces for the FDs are reported + // as symlink. Otherwise, they are reported as regular files. + produceSymlink bool +} + +func (i *fdDir) lookup(name string) (*vfs.FileDescription, kernel.FDFlags, error) { + fd, err := strconv.ParseUint(name, 10, 64) + if err != nil { + return nil, kernel.FDFlags{}, syserror.ENOENT + } + + var ( + file *vfs.FileDescription + flags kernel.FDFlags + ) + i.task.WithMuLocked(func(t *kernel.Task) { + if fdTable := t.FDTable(); fdTable != nil { + file, flags = fdTable.GetVFS2(int32(fd)) + } + }) + if file == nil { + return nil, kernel.FDFlags{}, syserror.ENOENT + } + return file, flags, nil +} + +// IterDirents implements kernfs.inodeDynamicLookup. +func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, absOffset, relOffset int64) (int64, error) { + var fds []int32 + i.task.WithMuLocked(func(t *kernel.Task) { + if fdTable := t.FDTable(); fdTable != nil { + fds = fdTable.GetFDs() + } + }) + + offset := absOffset + relOffset + typ := uint8(linux.DT_REG) + if i.produceSymlink { + typ = linux.DT_LNK + } + + // Find the appropriate starting point. + idx := sort.Search(len(fds), func(i int) bool { return fds[i] >= int32(relOffset) }) + if idx >= len(fds) { + return offset, nil + } + for _, fd := range fds[idx:] { + dirent := vfs.Dirent{ + Name: strconv.FormatUint(uint64(fd), 10), + Type: typ, + Ino: i.inoGen.NextIno(), + NextOff: offset + 1, + } + if err := cb.Handle(dirent); err != nil { + return offset, err + } + offset++ + } + return offset, nil +} + +// fdDirInode represents the inode for /proc/[pid]/fd directory. +// +// +stateify savable +type fdDirInode struct { + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeAttrs + kernfs.OrderedChildren + kernfs.AlwaysValid + fdDir +} + +var _ kernfs.Inode = (*fdDirInode)(nil) + +func newFDDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { + inode := &fdDirInode{ + fdDir: fdDir{ + inoGen: inoGen, + task: task, + produceSymlink: true, + }, + } + inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) + + dentry := &kernfs.Dentry{} + dentry.Init(inode) + inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + + return dentry +} + +// Lookup implements kernfs.inodeDynamicLookup. +func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + file, _, err := i.lookup(name) + if err != nil { + return nil, err + } + taskDentry := newFDSymlink(i.task.Credentials(), file, i.inoGen.NextIno()) + return taskDentry.VFSDentry(), nil +} + +// Open implements kernfs.Inode. +func (i *fdDirInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts) + return fd.VFSFileDescription(), nil +} + +// CheckPermissions implements kernfs.Inode. +// +// This is to match Linux, which uses a special permission handler to guarantee +// that a process can still access /proc/self/fd after it has executed +// setuid. See fs/proc/fd.c:proc_fd_permission. +func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { + err := i.InodeAttrs.CheckPermissions(ctx, creds, ats) + if err == nil { + // Access granted, no extra check needed. + return nil + } + if t := kernel.TaskFromContext(ctx); t != nil { + // Allow access if the task trying to access it is in the thread group + // corresponding to this directory. + if i.task.ThreadGroup() == t.ThreadGroup() { + // Access granted (overridden). + return nil + } + } + return err +} + +// fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file. +// +// +stateify savable +type fdSymlink struct { + refs.AtomicRefCount + kernfs.InodeAttrs + kernfs.InodeSymlink + + file *vfs.FileDescription +} + +var _ kernfs.Inode = (*fdSymlink)(nil) + +func newFDSymlink(creds *auth.Credentials, file *vfs.FileDescription, ino uint64) *kernfs.Dentry { + file.IncRef() + inode := &fdSymlink{file: file} + inode.Init(creds, ino, linux.ModeSymlink|0777) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} + +func (s *fdSymlink) Readlink(ctx context.Context) (string, error) { + root := vfs.RootFromContext(ctx) + defer root.DecRef() + + vfsObj := s.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem() + return vfsObj.PathnameWithDeleted(ctx, root, s.file.VirtualDentry()) +} + +func (s *fdSymlink) DecRef() { + s.AtomicRefCount.DecRefWithDestructor(func() { + s.Destroy() + }) +} + +func (s *fdSymlink) Destroy() { + s.file.DecRef() +} + +// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory. +// +// +stateify savable +type fdInfoDirInode struct { + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeAttrs + kernfs.OrderedChildren + kernfs.AlwaysValid + fdDir +} + +var _ kernfs.Inode = (*fdInfoDirInode)(nil) + +func newFDInfoDirInode(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { + inode := &fdInfoDirInode{ + fdDir: fdDir{ + inoGen: inoGen, + task: task, + }, + } + inode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555) + + dentry := &kernfs.Dentry{} + dentry.Init(inode) + inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + + return dentry +} + +// Lookup implements kernfs.inodeDynamicLookup. +func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + file, flags, err := i.lookup(name) + if err != nil { + return nil, err + } + + data := &fdInfoData{file: file, flags: flags} + dentry := newTaskOwnedFile(i.task, i.inoGen.NextIno(), 0444, data) + return dentry.VFSDentry(), nil +} + +// Open implements kernfs.Inode. +func (i *fdInfoDirInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts) + return fd.VFSFileDescription(), nil +} + +// fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd]. +// +// +stateify savable +type fdInfoData struct { + kernfs.DynamicBytesFile + refs.AtomicRefCount + + file *vfs.FileDescription + flags kernel.FDFlags +} + +var _ dynamicInode = (*fdInfoData)(nil) + +func (d *fdInfoData) DecRef() { + d.AtomicRefCount.DecRefWithDestructor(d.destroy) +} + +func (d *fdInfoData) destroy() { + d.file.DecRef() +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { + // TODO(b/121266871): Include pos, locks, and other data. For now we only + // have flags. + // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt + flags := uint(d.file.StatusFlags()) | d.flags.ToLinuxFileFlags() + fmt.Fprintf(buf, "flags:\t0%o\n", flags) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index efd3b3453..8c743df8d 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -18,10 +18,14 @@ import ( "bytes" "fmt" "io" + "sort" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -496,7 +500,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider. +// ioUsage is the /proc/[pid]/io and /proc/[pid]/task/[tid]/io data provider. type ioUsage interface { // IOUsage returns the io usage data. IOUsage() *usage.IO @@ -525,3 +529,293 @@ func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled) return nil } + +// oomScoreAdj is a stub of the /proc/<pid>/oom_score_adj file. +// +// +stateify savable +type oomScoreAdj struct { + kernfs.DynamicBytesFile + + task *kernel.Task +} + +var _ vfs.WritableDynamicBytesSource = (*oomScoreAdj)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error { + if o.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } + fmt.Fprintf(buf, "%d\n", o.task.OOMScoreAdj()) + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is large. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + + if o.task.ExitState() == kernel.TaskExitDead { + return 0, syserror.ESRCH + } + if err := o.task.SetOOMScoreAdj(v); err != nil { + return 0, err + } + + return n, nil +} + +// exeSymlink is an symlink for the /proc/[pid]/exe file. +// +// +stateify savable +type exeSymlink struct { + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeSymlink + + task *kernel.Task +} + +var _ kernfs.Inode = (*exeSymlink)(nil) + +func newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry { + inode := &exeSymlink{task: task} + inode.Init(task.Credentials(), ino, linux.ModeSymlink|0777) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} + +// Readlink implements kernfs.Inode. +func (s *exeSymlink) Readlink(ctx context.Context) (string, error) { + if !kernel.ContextCanTrace(ctx, s.task, false) { + return "", syserror.EACCES + } + + // Pull out the executable for /proc/[pid]/exe. + exec, err := s.executable() + if err != nil { + return "", err + } + defer exec.DecRef() + + return exec.PathnameWithDeleted(ctx), nil +} + +func (s *exeSymlink) executable() (file fsbridge.File, err error) { + s.task.WithMuLocked(func(t *kernel.Task) { + mm := t.MemoryManager() + if mm == nil { + // TODO(b/34851096): Check shouldn't allow Readlink once the + // Task is zombied. + err = syserror.EACCES + return + } + + // The MemoryManager may be destroyed, in which case + // MemoryManager.destroy will simply set the executable to nil + // (with locks held). + file = mm.Executable() + if file == nil { + err = syserror.ENOENT + } + }) + return +} + +// forEachMountSource runs f for the process root mount and each mount that is +// a descendant of the root. +func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) { + var fsctx *kernel.FSContext + t.WithMuLocked(func(t *kernel.Task) { + fsctx = t.FSContext() + }) + if fsctx == nil { + // The task has been destroyed. Nothing to show here. + return + } + + // All mount points must be relative to the rootDir, and mounts outside + // will be excluded. + rootDir := fsctx.RootDirectory() + if rootDir == nil { + // The task has been destroyed. Nothing to show here. + return + } + defer rootDir.DecRef() + + mnt := t.MountNamespace().FindMount(rootDir) + if mnt == nil { + // Has it just been unmounted? + return + } + ms := t.MountNamespace().AllMountsUnder(mnt) + sort.Slice(ms, func(i, j int) bool { + return ms[i].ID < ms[j].ID + }) + for _, m := range ms { + mroot := m.Root() + if mroot == nil { + continue // No longer valid. + } + mountPath, desc := mroot.FullName(rootDir) + mroot.DecRef() + if !desc { + // MountSources that are not descendants of the chroot jail are ignored. + continue + } + fn(mountPath, m) + } +} + +// mountInfoData is used to implement /proc/[pid]/mountinfo. +// +// +stateify savable +type mountInfoData struct { + kernfs.DynamicBytesFile + + task *kernel.Task +} + +var _ dynamicInode = (*mountInfoData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { + forEachMount(i.task, func(mountPath string, m *fs.Mount) { + mroot := m.Root() + if mroot == nil { + return // No longer valid. + } + defer mroot.DecRef() + + // Format: + // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue + // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11) + + // (1) MountSource ID. + fmt.Fprintf(buf, "%d ", m.ID) + + // (2) Parent ID (or this ID if there is no parent). + pID := m.ID + if !m.IsRoot() && !m.IsUndo() { + pID = m.ParentID + } + fmt.Fprintf(buf, "%d ", pID) + + // (3) Major:Minor device ID. We don't have a superblock, so we + // just use the root inode device number. + sa := mroot.Inode.StableAttr + fmt.Fprintf(buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor) + + // (4) Root: the pathname of the directory in the filesystem + // which forms the root of this mount. + // + // NOTE(b/78135857): This will always be "/" until we implement + // bind mounts. + fmt.Fprintf(buf, "/ ") + + // (5) Mount point (relative to process root). + fmt.Fprintf(buf, "%s ", mountPath) + + // (6) Mount options. + flags := mroot.Inode.MountSource.Flags + opts := "rw" + if flags.ReadOnly { + opts = "ro" + } + if flags.NoAtime { + opts += ",noatime" + } + if flags.NoExec { + opts += ",noexec" + } + fmt.Fprintf(buf, "%s ", opts) + + // (7) Optional fields: zero or more fields of the form "tag[:value]". + // (8) Separator: the end of the optional fields is marked by a single hyphen. + fmt.Fprintf(buf, "- ") + + // (9) Filesystem type. + fmt.Fprintf(buf, "%s ", mroot.Inode.MountSource.FilesystemType) + + // (10) Mount source: filesystem-specific information or "none". + fmt.Fprintf(buf, "none ") + + // (11) Superblock options, and final newline. + fmt.Fprintf(buf, "%s\n", superBlockOpts(mountPath, mroot.Inode.MountSource)) + }) + return nil +} + +func superBlockOpts(mountPath string, msrc *fs.MountSource) string { + // gVisor doesn't (yet) have a concept of super block options, so we + // use the ro/rw bit from the mount flag. + opts := "rw" + if msrc.Flags.ReadOnly { + opts = "ro" + } + + // NOTE(b/147673608): If the mount is a cgroup, we also need to include + // the cgroup name in the options. For now we just read that from the + // path. + // TODO(gvisor.dev/issues/190): Once gVisor has full cgroup support, we + // should get this value from the cgroup itself, and not rely on the + // path. + if msrc.FilesystemType == "cgroup" { + splitPath := strings.Split(mountPath, "/") + cgroupType := splitPath[len(splitPath)-1] + opts += "," + cgroupType + } + return opts +} + +// mountsData is used to implement /proc/[pid]/mounts. +// +// +stateify savable +type mountsData struct { + kernfs.DynamicBytesFile + + task *kernel.Task +} + +var _ dynamicInode = (*mountInfoData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + forEachMount(i.task, func(mountPath string, m *fs.Mount) { + // Format: + // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order> + // + // We use the filesystem name as the first field, since there + // is no real block device we can point to, and we also should + // not expose anything about the remote filesystem. + // + // Only ro/rw option is supported for now. + // + // The "needs dump"and fsck flags are always 0, which is allowed. + root := m.Root() + if root == nil { + return // No longer valid. + } + defer root.DecRef() + + flags := root.Inode.MountSource.Flags + opts := "rw" + if flags.ReadOnly { + opts = "ro" + } + fmt.Fprintf(buf, "%s %s %s %s %d %d\n", "none", mountPath, root.Inode.MountSource.FilesystemType, opts, 0, 0) + }) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/task_net.go index d4e1812d8..373a7b17d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -37,12 +37,13 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { +func newTaskNetDir(task *kernel.Task, inoGen InoGenerator) *kernfs.Dentry { + k := task.Kernel() + pidns := task.PIDNamespace() + root := auth.NewRootCredentials(pidns.UserNamespace()) + var contents map[string]*kernfs.Dentry - // TODO(gvisor.dev/issue/1833): Support for using the network stack in the - // network namespace of the calling process. We should make this per-process, - // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. - if stack := k.RootNetworkNamespace().Stack(); stack != nil { + if stack := task.NetworkNamespace().Stack(); stack != nil { const ( arp = "IP address HW type Flags HW address Mask Device\n" netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n" @@ -53,6 +54,8 @@ func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k ) psched := fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)) + // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task + // network namespace. contents = map[string]*kernfs.Dentry{ "dev": newDentry(root, inoGen.NextIno(), 0444, &netDevData{stack: stack}), "snmp": newDentry(root, inoGen.NextIno(), 0444, &netSnmpData{stack: stack}), @@ -84,7 +87,7 @@ func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k } } - return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, contents) + return newTaskOwnedDir(task, inoGen.NextIno(), 0555, contents) } // ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6. diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 10c08fa90..9f2ef8200 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -46,6 +46,7 @@ type tasksInode struct { kernfs.InodeDirectoryNoNewChildren kernfs.InodeAttrs kernfs.OrderedChildren + kernfs.AlwaysValid inoGen InoGenerator pidns *kernel.PIDNamespace @@ -66,23 +67,23 @@ var _ kernfs.Inode = (*tasksInode)(nil) func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]*kernfs.Dentry{ - "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(cpuInfoData(k))), - //"filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}), - "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}), - "sys": newSysDir(root, inoGen, k), - "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"), - "net": newNetDir(root, inoGen, k), - "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}), - "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}), - "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}), + "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))), + "filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}), + "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}), + "sys": newSysDir(root, inoGen, k), + "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}), + "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"), + "net": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/net"), + "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}), + "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}), + "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}), } inode := &tasksInode{ pidns: pidns, inoGen: inoGen, - selfSymlink: newSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(), - threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), 0444, pidns).VFSDentry(), + selfSymlink: newSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(), + threadSelfSymlink: newThreadSelfSymlink(root, inoGen.NextIno(), pidns).VFSDentry(), cgroupControllers: cgroupControllers, } inode.InodeAttrs.Init(root, inoGen.NextIno(), linux.ModeDirectory|0555) @@ -121,11 +122,6 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro return taskDentry.VFSDentry(), nil } -// Valid implements kernfs.inodeDynamicLookup. -func (i *tasksInode) Valid(ctx context.Context) bool { - return true -} - // IterDirents implements kernfs.inodeDynamicLookup. func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 @@ -211,17 +207,36 @@ func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.Open return fd.VFSFileDescription(), nil } -func (i *tasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx { - stat := i.InodeAttrs.Stat(vsfs) +func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(vsfs, opts) + if err != nil { + return linux.Statx{}, err + } - // Add dynamic children to link count. - for _, tg := range i.pidns.ThreadGroups() { - if leader := tg.Leader(); leader != nil { - stat.Nlink++ + if opts.Mask&linux.STATX_NLINK != 0 { + // Add dynamic children to link count. + for _, tg := range i.pidns.ThreadGroups() { + if leader := tg.Leader(); leader != nil { + stat.Nlink++ + } } } - return stat + return stat, nil +} + +// staticFileSetStat implements a special static file that allows inode +// attributes to be set. This is to support /proc files that are readonly, but +// allow attributes to be set. +type staticFileSetStat struct { + dynamicBytesFileSetAttr + vfs.StaticData +} + +var _ dynamicInode = (*staticFileSetStat)(nil) + +func newStaticFileSetStat(data string) *staticFileSetStat { + return &staticFileSetStat{StaticData: vfs.StaticData{Data: data}} } func cpuInfoData(k *kernel.Kernel) string { diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 434998910..882c1981e 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,9 +41,9 @@ type selfSymlink struct { var _ kernfs.Inode = (*selfSymlink)(nil) -func newSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry { +func newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry { inode := &selfSymlink{pidns: pidns} - inode.Init(creds, ino, linux.ModeSymlink|perm) + inode.Init(creds, ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) @@ -62,6 +63,11 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { return strconv.FormatUint(uint64(tgid), 10), nil } +// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + type threadSelfSymlink struct { kernfs.InodeAttrs kernfs.InodeNoopRefCount @@ -72,9 +78,9 @@ type threadSelfSymlink struct { var _ kernfs.Inode = (*threadSelfSymlink)(nil) -func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, pidns *kernel.PIDNamespace) *kernfs.Dentry { +func newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry { inode := &threadSelfSymlink{pidns: pidns} - inode.Init(creds, ino, linux.ModeSymlink|perm) + inode.Init(creds, ino, linux.ModeSymlink|0777) d := &kernfs.Dentry{} d.Init(inode) @@ -95,6 +101,23 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { return fmt.Sprintf("%d/task/%d", tgid, tid), nil } +// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// dynamicBytesFileSetAttr implements a special file that allows inode +// attributes to be set. This is to support /proc files that are readonly, but +// allow attributes to be set. +type dynamicBytesFileSetAttr struct { + kernfs.DynamicBytesFile +} + +// SetStat implements Inode.SetStat. +func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts) +} + // cpuStats contains the breakdown of CPU time for /proc/stat. type cpuStats struct { // user is time spent in userspace tasks with non-positive niceness. @@ -137,22 +160,20 @@ func (c cpuStats) String() string { // // +stateify savable type statData struct { - kernfs.DynamicBytesFile - - // k is the owning Kernel. - k *kernel.Kernel + dynamicBytesFileSetAttr } var _ dynamicInode = (*statData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. -func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error { +func (*statData) Generate(ctx context.Context, buf *bytes.Buffer) error { // TODO(b/37226836): We currently export only zero CPU stats. We could // at least provide some aggregate stats. var cpu cpuStats fmt.Fprintf(buf, "cpu %s\n", cpu) - for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ { + k := kernel.KernelFromContext(ctx) + for c, max := uint(0), k.ApplicationCores(); c < max; c++ { fmt.Fprintf(buf, "cpu%d %s\n", c, cpu) } @@ -176,7 +197,7 @@ func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "ctxt 0\n") // CLOCK_REALTIME timestamp from boot, in seconds. - fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds()) + fmt.Fprintf(buf, "btime %d\n", k.Timekeeper().BootTime().Seconds()) // Total number of clones. // TODO(b/37226836): Count this. @@ -203,13 +224,13 @@ func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error { // // +stateify savable type loadavgData struct { - kernfs.DynamicBytesFile + dynamicBytesFileSetAttr } var _ dynamicInode = (*loadavgData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. -func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error { +func (*loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error { // TODO(b/62345059): Include real data in fields. // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods. // Column 4-5: currently running processes and the total number of processes. @@ -222,17 +243,15 @@ func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error { // // +stateify savable type meminfoData struct { - kernfs.DynamicBytesFile - - // k is the owning Kernel. - k *kernel.Kernel + dynamicBytesFileSetAttr } var _ dynamicInode = (*meminfoData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. -func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { - mf := d.k.MemoryFile() +func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { + k := kernel.KernelFromContext(ctx) + mf := k.MemoryFile() mf.UpdateUsage() snapshot, totalUsage := usage.MemoryAccounting.Copy() totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage) @@ -275,7 +294,7 @@ func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error { // // +stateify savable type uptimeData struct { - kernfs.DynamicBytesFile + dynamicBytesFileSetAttr } var _ dynamicInode = (*uptimeData)(nil) @@ -294,17 +313,15 @@ func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error { // // +stateify savable type versionData struct { - kernfs.DynamicBytesFile - - // k is the owning Kernel. - k *kernel.Kernel + dynamicBytesFileSetAttr } var _ dynamicInode = (*versionData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. -func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { - init := v.k.GlobalInit() +func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { + k := kernel.KernelFromContext(ctx) + init := k.GlobalInit() if init == nil { // Attempted to read before the init Task is created. This can // only occur during startup, which should never need to read @@ -335,3 +352,19 @@ func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version) return nil } + +// filesystemsData backs /proc/filesystems. +// +// +stateify savable +type filesystemsData struct { + kernfs.DynamicBytesFile +} + +var _ dynamicInode = (*filesystemsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + k := kernel.KernelFromContext(ctx) + k.VFS().GenerateProcFilesystems(buf) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index c5d531fe0..d0f97c137 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -47,10 +48,11 @@ var ( var ( tasksStaticFiles = map[string]testutil.DirentType{ "cpuinfo": linux.DT_REG, + "filesystems": linux.DT_REG, "loadavg": linux.DT_REG, "meminfo": linux.DT_REG, "mounts": linux.DT_LNK, - "net": linux.DT_DIR, + "net": linux.DT_LNK, "self": linux.DT_LNK, "stat": linux.DT_REG, "sys": linux.DT_DIR, @@ -63,21 +65,29 @@ var ( "thread-self": threadSelfLink.NextOff, } taskStaticFiles = map[string]testutil.DirentType{ - "auxv": linux.DT_REG, - "cgroup": linux.DT_REG, - "cmdline": linux.DT_REG, - "comm": linux.DT_REG, - "environ": linux.DT_REG, - "gid_map": linux.DT_REG, - "io": linux.DT_REG, - "maps": linux.DT_REG, - "ns": linux.DT_DIR, - "smaps": linux.DT_REG, - "stat": linux.DT_REG, - "statm": linux.DT_REG, - "status": linux.DT_REG, - "task": linux.DT_DIR, - "uid_map": linux.DT_REG, + "auxv": linux.DT_REG, + "cgroup": linux.DT_REG, + "cmdline": linux.DT_REG, + "comm": linux.DT_REG, + "environ": linux.DT_REG, + "exe": linux.DT_LNK, + "fd": linux.DT_DIR, + "fdinfo": linux.DT_DIR, + "gid_map": linux.DT_REG, + "io": linux.DT_REG, + "maps": linux.DT_REG, + "mountinfo": linux.DT_REG, + "mounts": linux.DT_REG, + "net": linux.DT_DIR, + "ns": linux.DT_DIR, + "oom_score": linux.DT_REG, + "oom_score_adj": linux.DT_REG, + "smaps": linux.DT_REG, + "stat": linux.DT_REG, + "statm": linux.DT_REG, + "status": linux.DT_REG, + "task": linux.DT_DIR, + "uid_map": linux.DT_REG, } ) @@ -93,17 +103,37 @@ func setup(t *testing.T) *testutil.System { k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - fsOpts := vfs.GetFilesystemOptions{ - InternalData: &InternalData{ - Cgroups: map[string]string{ - "cpuset": "/foo/cpuset", - "memory": "/foo/memory", + + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("NewMountNamespace(): %v", err) + } + pop := &vfs.PathOperation{ + Root: mntns.Root(), + Start: mntns.Root(), + Path: fspath.Parse("/proc"), + } + if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil { + t.Fatalf("MkDir(/proc): %v", err) + } + + pop = &vfs.PathOperation{ + Root: mntns.Root(), + Start: mntns.Root(), + Path: fspath.Parse("/proc"), + } + mntOpts := &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: &InternalData{ + Cgroups: map[string]string{ + "cpuset": "/foo/cpuset", + "memory": "/foo/memory", + }, }, }, } - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", Name, &fsOpts) - if err != nil { - t.Fatalf("NewMountNamespace(): %v", err) + if err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil { + t.Fatalf("MountAt(/proc): %v", err) } return testutil.NewSystem(ctx, t, k.VFS(), mntns) } @@ -112,7 +142,7 @@ func TestTasksEmpty(t *testing.T) { s := setup(t) defer s.Destroy() - collector := s.ListDirents(s.PathOpAtRoot("/")) + collector := s.ListDirents(s.PathOpAtRoot("/proc")) s.AssertAllDirentTypes(collector, tasksStaticFiles) s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs) } @@ -138,7 +168,7 @@ func TestTasks(t *testing.T) { expectedDirents[fmt.Sprintf("%d", i+1)] = linux.DT_DIR } - collector := s.ListDirents(s.PathOpAtRoot("/")) + collector := s.ListDirents(s.PathOpAtRoot("/proc")) s.AssertAllDirentTypes(collector, expectedDirents) s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs) @@ -178,7 +208,7 @@ func TestTasks(t *testing.T) { } // Test lookup. - for _, path := range []string{"/1", "/2"} { + for _, path := range []string{"/proc/1", "/proc/2"} { fd, err := s.VFS.OpenAt( s.Ctx, s.Creds, @@ -188,6 +218,7 @@ func TestTasks(t *testing.T) { if err != nil { t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err) } + defer fd.DecRef() buf := make([]byte, 1) bufIOSeq := usermem.BytesIOSequence(buf) if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR { @@ -198,10 +229,10 @@ func TestTasks(t *testing.T) { if _, err := s.VFS.OpenAt( s.Ctx, s.Creds, - s.PathOpAtRoot("/9999"), + s.PathOpAtRoot("/proc/9999"), &vfs.OpenOptions{}, ); err != syserror.ENOENT { - t.Fatalf("wrong error from vfsfs.OpenAt(/9999): %v", err) + t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err) } } @@ -299,12 +330,13 @@ func TestTasksOffset(t *testing.T) { fd, err := s.VFS.OpenAt( s.Ctx, s.Creds, - s.PathOpAtRoot("/"), + s.PathOpAtRoot("/proc"), &vfs.OpenOptions{}, ) if err != nil { t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) } + defer fd.DecRef() if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil { t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err) } @@ -341,7 +373,7 @@ func TestTask(t *testing.T) { t.Fatalf("CreateTask(): %v", err) } - collector := s.ListDirents(s.PathOpAtRoot("/1")) + collector := s.ListDirents(s.PathOpAtRoot("/proc/1")) s.AssertAllDirentTypes(collector, taskStaticFiles) } @@ -359,14 +391,14 @@ func TestProcSelf(t *testing.T) { collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{ Root: s.Root, Start: s.Root, - Path: fspath.Parse("/self/"), + Path: fspath.Parse("/proc/self/"), FollowFinalSymlink: true, }) s.AssertAllDirentTypes(collector, taskStaticFiles) } func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.FileDescription) { - t.Logf("Iterating: /proc%s", fd.MappedName(ctx)) + t.Logf("Iterating: %s", fd.MappedName(ctx)) var collector testutil.DirentCollector if err := fd.IterDirents(ctx, &collector); err != nil { @@ -409,6 +441,7 @@ func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.F t.Errorf("vfsfs.OpenAt(%v) failed: %v", childPath, err) continue } + defer child.DecRef() stat, err := child.Stat(ctx, vfs.StatOptions{}) if err != nil { t.Errorf("Stat(%v) failed: %v", childPath, err) @@ -429,6 +462,22 @@ func TestTree(t *testing.T) { defer s.Destroy() k := kernel.KernelFromContext(s.Ctx) + + pop := &vfs.PathOperation{ + Root: s.Root, + Start: s.Root, + Path: fspath.Parse("test-file"), + } + opts := &vfs.OpenOptions{ + Flags: linux.O_RDONLY | linux.O_CREAT, + Mode: 0777, + } + file, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, opts) + if err != nil { + t.Fatalf("failed to create test file: %v", err) + } + defer file.DecRef() + var tasks []*kernel.Task for i := 0; i < 5; i++ { tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) @@ -436,6 +485,8 @@ func TestTree(t *testing.T) { if err != nil { t.Fatalf("CreateTask(): %v", err) } + // Add file to populate /proc/[pid]/fd and fdinfo directories. + task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{}) tasks = append(tasks, task) } @@ -443,11 +494,12 @@ func TestTree(t *testing.T) { fd, err := s.VFS.OpenAt( ctx, auth.CredentialsFromContext(s.Ctx), - &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/")}, + &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc")}, &vfs.OpenOptions{}, ) if err != nil { - t.Fatalf("vfsfs.OpenAt(/) failed: %v", err) + t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err) } iterateDir(ctx, t, s, fd) + fd.DecRef() } diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index c36c4fa11..7abfd62f2 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -94,15 +94,17 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte return &d.dentry } -// SetStat implements kernfs.Inode.SetStat. -func (d *dir) SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error { +// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } // Open implements kernfs.Inode.Open. func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &kernfs.GenericDirectoryFD{} - fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts) + if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts); err != nil { + return nil, err + } return fd.VFSFileDescription(), nil } diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index e4f36f4ae..0e4053a46 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -16,12 +16,14 @@ go_library( "//pkg/cpuid", "//pkg/fspath", "//pkg/memutil", + "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/sched", "//pkg/sentry/limits", "//pkg/sentry/loader", + "//pkg/sentry/mm", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", "//pkg/sentry/platform/kvm", diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 488478e29..c16a36cdb 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -23,13 +23,16 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/loader" + "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/time" @@ -123,10 +126,17 @@ func Boot() (*kernel.Kernel, error) { // CreateTask creates a new bare bones task for tests. func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) { k := kernel.KernelFromContext(ctx) + exe, err := newFakeExecutable(ctx, k.VFS(), auth.CredentialsFromContext(ctx), root) + if err != nil { + return nil, err + } + m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) + m.SetExecutable(fsbridge.NewVFSFile(exe)) + config := &kernel.TaskConfig{ Kernel: k, ThreadGroup: tc, - TaskContext: &kernel.TaskContext{Name: name}, + TaskContext: &kernel.TaskContext{Name: name, MemoryManager: m}, Credentials: auth.CredentialsFromContext(ctx), NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), @@ -135,10 +145,25 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(), MountNamespaceVFS2: mntns, FSContext: kernel.NewFSContextVFS2(root, cwd, 0022), + FDTable: k.NewFDTable(), } return k.TaskSet().NewTask(config) } +func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) { + const name = "executable" + pop := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(name), + } + opts := &vfs.OpenOptions{ + Flags: linux.O_RDONLY | linux.O_CREAT, + Mode: 0777, + } + return vfsObj.OpenAt(ctx, creds, pop, opts) +} + func createMemoryFile() (*pgalloc.MemoryFile, error) { const memfileName = "test-memory" memfd, err := memutil.CreateMemFD(memfileName, 0) diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e1b551422..75d01b853 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) @@ -154,6 +155,17 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa return create(parent, name) } +// AccessAt implements vfs.Filesystem.Impl.AccessAt. +func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + d, err := resolveLocked(rp) + if err != nil { + return err + } + return d.inode.checkPermissions(creds, ats, d.inode.isDir()) +} + // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { fs.mu.RLock() @@ -563,7 +575,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts if err != nil { return err } - return d.inode.setStat(opts.Stat) + return d.inode.setStat(ctx, rp.Credentials(), &opts.Stat) } // StatAt implements vfs.FilesystemImpl.StatAt. diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 711442424..5a2896bf6 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -308,11 +308,18 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, nil } f := fd.inode().impl.(*regularFile) - end := offset + srclen - if end < offset { + if end := offset + srclen; end < offset { // Overflow. return 0, syserror.EFBIG } + + var err error + srclen, err = vfs.CheckLimit(ctx, offset, srclen) + if err != nil { + return 0, err + } + src = src.TakeFirst64(srclen) + f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 521206305..ff69372b3 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -299,10 +299,16 @@ func (i *inode) statTo(stat *linux.Statx) { } } -func (i *inode) setStat(stat linux.Statx) error { +func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error { if stat.Mask == 0 { return nil } + if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE) != 0 { + return syserror.EPERM + } + if err := vfs.CheckSetStat(ctx, creds, stat, uint16(atomic.LoadUint32(&i.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + return err + } i.mu.Lock() var ( needsMtimeBump bool @@ -457,5 +463,6 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - return fd.inode().setStat(opts.Stat) + creds := auth.CredentialsFromContext(ctx) + return fd.inode().setStat(ctx, creds, &opts.Stat) } diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go index c16667e7f..029af3025 100644 --- a/pkg/sentry/inet/namespace.go +++ b/pkg/sentry/inet/namespace.go @@ -23,7 +23,10 @@ type Namespace struct { // creator allows kernel to create new network stack for network namespaces. // If nil, no networking will function if network is namespaced. - creator NetworkStackCreator + // + // At afterLoad(), creator will be used to create network stack. Stateify + // needs to wait for this field to be loaded before calling afterLoad(). + creator NetworkStackCreator `state:"wait"` // isRoot indicates whether this is the root network namespace. isRoot bool diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go index a0d35d350..8e9f200d0 100644 --- a/pkg/sentry/kernel/epoll/epoll_state.go +++ b/pkg/sentry/kernel/epoll/epoll_state.go @@ -38,11 +38,14 @@ func (e *EventPoll) afterLoad() { } } - for it := e.waitingList.Front(); it != nil; it = it.Next() { - if it.id.File.Readiness(it.mask) != 0 { - e.waitingList.Remove(it) - e.readyList.PushBack(it) - it.curList = &e.readyList + for it := e.waitingList.Front(); it != nil; { + entry := it + it = it.Next() + + if entry.id.File.Readiness(entry.mask) != 0 { + e.waitingList.Remove(entry) + e.readyList.PushBack(entry) + entry.curList = &e.readyList e.Notify(waiter.EventIn) } } diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 58001d56c..d09d97825 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -191,10 +191,12 @@ func (f *FDTable) Size() int { return int(size) } -// forEach iterates over all non-nil files. +// forEach iterates over all non-nil files in sorted order. // // It is the caller's responsibility to acquire an appropriate lock. func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) { + // retries tracks the number of failed TryIncRef attempts for the same FD. + retries := 0 fd := int32(0) for { file, fileVFS2, flags, ok := f.getAll(fd) @@ -204,17 +206,26 @@ func (f *FDTable) forEach(fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDes switch { case file != nil: if !file.TryIncRef() { + retries++ + if retries > 1000 { + panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, FileOps: %+v", fd, file, file.FileOperations)) + } continue // Race caught. } fn(fd, file, nil, flags) file.DecRef() case fileVFS2 != nil: if !fileVFS2.TryIncRef() { + retries++ + if retries > 1000 { + panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, Impl: %+v", fd, fileVFS2, fileVFS2.Impl())) + } continue // Race caught. } fn(fd, nil, fileVFS2, flags) fileVFS2.DecRef() } + retries = 0 fd++ } } @@ -327,7 +338,7 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc fd = f.next } for fd < end { - if d, _, _ := f.get(fd); d == nil { + if d, _, _ := f.getVFS2(fd); d == nil { f.setVFS2(fd, file, flags) if fd == f.next { // Update next search start position. @@ -447,7 +458,10 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) { } } -// GetFDs returns a list of valid fds. +// GetFDs returns a sorted list of valid fds. +// +// Precondition: The caller must be running on the task goroutine, or Task.mu +// must be locked. func (f *FDTable) GetFDs() []int32 { fds := make([]int32, 0, int(atomic.LoadInt32(&f.used))) f.forEach(func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) { @@ -522,7 +536,9 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { case orig2 != nil: orig2.IncRef() } - f.setAll(fd, nil, nil, FDFlags{}) // Zap entry. + if orig != nil || orig2 != nil { + f.setAll(fd, nil, nil, FDFlags{}) // Zap entry. + } return orig, orig2 } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 1d627564f..6feda8fa1 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -467,6 +467,11 @@ func (k *Kernel) flushMountSourceRefs() error { // // Precondition: Must be called with the kernel paused. func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) (err error) { + // TODO(gvisor.dev/issue/1663): Add save support for VFS2. + if VFS2Enabled { + return nil + } + ts.mu.RLock() defer ts.mu.RUnlock() for t := range ts.Root.tids { @@ -484,7 +489,7 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File, *vfs.FileDescription) error) } func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { - // TODO(gvisor.dev/issues/1663): Add save support for VFS2. + // TODO(gvisor.dev/issue/1663): Add save support for VFS2. return ts.forEachFDPaused(func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil @@ -533,6 +538,11 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { } func (ts *TaskSet) unregisterEpollWaiters() { + // TODO(gvisor.dev/issue/1663): Add save support for VFS2. + if VFS2Enabled { + return + } + ts.mu.RLock() defer ts.mu.RUnlock() for t := range ts.Root.tids { @@ -1005,11 +1015,14 @@ func (k *Kernel) pauseTimeLocked() { // This means we'll iterate FDTables shared by multiple tasks repeatedly, // but ktime.Timer.Pause is idempotent so this is harmless. if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok { - tfd.PauseTimer() - } - }) + // TODO(gvisor.dev/issue/1663): Add save support for VFS2. + if !VFS2Enabled { + t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok { + tfd.PauseTimer() + } + }) + } } } k.timekeeper.PauseUpdates() @@ -1034,12 +1047,15 @@ func (k *Kernel) resumeTimeLocked() { it.ResumeTimer() } } - if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok { - tfd.ResumeTimer() - } - }) + // TODO(gvisor.dev/issue/1663): Add save support for VFS2. + if !VFS2Enabled { + if t.fdTable != nil { + t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok { + tfd.ResumeTimer() + } + }) + } } } } diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 1000f3287..c00fa1138 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -554,6 +554,7 @@ func (s *sem) wakeWaiters() { for w := s.waiters.Front(); w != nil; { if s.value < w.value { // Still blocked, skip it. + w = w.Next() continue } w.ch <- struct{}{} diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 047b5214d..0e19286de 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -246,7 +246,7 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error { var lastErr error for tg := range tasks.Root.tgids { - if tg.ProcessGroup() == pg { + if tg.processGroup == pg { tg.signalHandlers.mu.Lock() infoCopy := *info if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil { diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 2cee2e6ed..8452ddf5b 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -37,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -847,3 +848,18 @@ func (t *Task) AbstractSockets() *AbstractSocketNamespace { func (t *Task) ContainerID() string { return t.containerID } + +// OOMScoreAdj gets the task's thread group's OOM score adjustment. +func (t *Task) OOMScoreAdj() int32 { + return atomic.LoadInt32(&t.tg.oomScoreAdj) +} + +// SetOOMScoreAdj sets the task's thread group's OOM score adjustment. The +// value should be between -1000 and 1000 inclusive. +func (t *Task) SetOOMScoreAdj(adj int32) error { + if adj > 1000 || adj < -1000 { + return syserror.EINVAL + } + atomic.StoreInt32(&t.tg.oomScoreAdj, adj) + return nil +} diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 78866f280..e1ecca99e 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -15,6 +15,8 @@ package kernel import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -260,6 +262,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { sh = sh.Fork() } tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy()) + tg.oomScoreAdj = atomic.LoadInt32(&t.tg.oomScoreAdj) rseqAddr = t.rseqAddr rseqSignature = t.rseqSignature } diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 268f62e9d..52849f5b3 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -254,6 +254,13 @@ type ThreadGroup struct { // // tty is protected by the signal mutex. tty *TTY + + // oomScoreAdj is the thread group's OOM score adjustment. This is + // currently not used but is maintained for consistency. + // TODO(gvisor.dev/issue/1967) + // + // oomScoreAdj is accessed using atomic memory operations. + oomScoreAdj int32 } // NewThreadGroup returns a new, empty thread group in PID namespace pidns. The diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index c61700892..04efa0147 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -82,6 +82,8 @@ fallback: // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 - // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. - MOVD R9, 8(RSP) - BL ·dieHandler(SB) + // R0: Fake the old PC as caller + // R1: First argument (vCPU) + MOVD.P R1, 8(RSP) // R1: First argument (vCPU) + MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller + B ·dieHandler(SB) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 2f02c03cf..eb5ed574e 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -17,10 +17,33 @@ package kvm import ( + "unsafe" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) +// dieArchSetup initialies the state for dieTrampoline. +// +// The arm64 dieTrampoline requires the vCPU to be set in R1, and the last PC +// to be in R0. The trampoline then simulates a call to dieHandler from the +// provided PC. +// //go:nosplit func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { - // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. + // If the vCPU is in user mode, we set the stack to the stored stack + // value in the vCPU itself. We don't want to unwind the user stack. + if guestRegs.Regs.Pstate&ring0.PSR_MODE_MASK == ring0.PSR_MODE_EL0t { + regs := c.CPU.Registers() + context.Regs[0] = regs.Regs[0] + context.Sp = regs.Sp + context.Regs[29] = regs.Regs[29] // stack base address + } else { + context.Regs[0] = guestRegs.Regs.Pc + context.Sp = guestRegs.Regs.Sp + context.Regs[29] = guestRegs.Regs.Regs[29] + context.Pstate = guestRegs.Regs.Pstate + } + context.Regs[1] = uint64(uintptr(unsafe.Pointer(c))) + context.Pc = uint64(dieTrampolineAddr) } diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index 972ba85c3..a9b4af43e 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -27,6 +27,38 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// userMemoryRegion is a region of physical memory. +// +// This mirrors kvm_memory_region. +type userMemoryRegion struct { + slot uint32 + flags uint32 + guestPhysAddr uint64 + memorySize uint64 + userspaceAddr uint64 +} + +// runData is the run structure. This may be mapped for synchronous register +// access (although that doesn't appear to be supported by my kernel at least). +// +// This mirrors kvm_run. +type runData struct { + requestInterruptWindow uint8 + _ [7]uint8 + + exitReason uint32 + readyForInterruptInjection uint8 + ifFlag uint8 + _ [2]uint8 + + cr8 uint64 + apicBase uint64 + + // This is the union data for exits. Interpretation depends entirely on + // the exitReason above (see vCPU code for more information). + data [32]uint64 +} + // KVM represents a lightweight VM context. type KVM struct { platform.NoCPUPreemptionDetection diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go index c5a6f9c7d..093497bc4 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64.go +++ b/pkg/sentry/platform/kvm/kvm_amd64.go @@ -21,17 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) -// userMemoryRegion is a region of physical memory. -// -// This mirrors kvm_memory_region. -type userMemoryRegion struct { - slot uint32 - flags uint32 - guestPhysAddr uint64 - memorySize uint64 - userspaceAddr uint64 -} - // userRegs represents KVM user registers. // // This mirrors kvm_regs. @@ -169,27 +158,6 @@ type modelControlRegisters struct { entries [16]modelControlRegister } -// runData is the run structure. This may be mapped for synchronous register -// access (although that doesn't appear to be supported by my kernel at least). -// -// This mirrors kvm_run. -type runData struct { - requestInterruptWindow uint8 - _ [7]uint8 - - exitReason uint32 - readyForInterruptInjection uint8 - ifFlag uint8 - _ [2]uint8 - - cr8 uint64 - apicBase uint64 - - // This is the union data for exits. Interpretation depends entirely on - // the exitReason above (see vCPU code for more information). - data [32]uint64 -} - // cpuidEntry is a single CPUID entry. // // This mirrors kvm_cpuid_entry2. diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 2319c86d3..79045651e 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -20,17 +20,6 @@ import ( "syscall" ) -// userMemoryRegion is a region of physical memory. -// -// This mirrors kvm_memory_region. -type userMemoryRegion struct { - slot uint32 - flags uint32 - guestPhysAddr uint64 - memorySize uint64 - userspaceAddr uint64 -} - type kvmOneReg struct { id uint64 addr uint64 @@ -53,27 +42,6 @@ type userRegs struct { fpRegs userFpsimdState } -// runData is the run structure. This may be mapped for synchronous register -// access (although that doesn't appear to be supported by my kernel at least). -// -// This mirrors kvm_run. -type runData struct { - requestInterruptWindow uint8 - _ [7]uint8 - - exitReason uint32 - readyForInterruptInjection uint8 - ifFlag uint8 - _ [2]uint8 - - cr8 uint64 - apicBase uint64 - - // This is the union data for exits. Interpretation depends entirely on - // the exitReason above (see vCPU code for more information). - data [32]uint64 -} - // updateGlobalOnce does global initialization. It has to be called only once. func updateGlobalOnce(fd int) error { physicalInit() diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 95abd321e..30402c2df 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -9,6 +9,7 @@ go_library( "ptrace.go", "ptrace_amd64.go", "ptrace_arm64.go", + "ptrace_arm64_unsafe.go", "ptrace_unsafe.go", "stub_amd64.s", "stub_arm64.s", diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go index db0212538..24fc5dc62 100644 --- a/pkg/sentry/platform/ptrace/ptrace_amd64.go +++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go @@ -31,3 +31,17 @@ func fpRegSet(useXsave bool) uintptr { func stackPointer(r *syscall.PtraceRegs) uintptr { return uintptr(r.Rsp) } + +// x86 use the fs_base register to store the TLS pointer which can be +// get/set in "func (t *thread) get/setRegs(regs *syscall.PtraceRegs)". +// So both of the get/setTLS() operations are noop here. + +// getTLS gets the thread local storage register. +func (t *thread) getTLS(tls *uint64) error { + return nil +} + +// setTLS sets the thread local storage register. +func (t *thread) setTLS(tls *uint64) error { + return nil +} diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go new file mode 100644 index 000000000..32b8a6be9 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go @@ -0,0 +1,62 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ptrace + +import ( + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// getTLS gets the thread local storage register. +func (t *thread) getTLS(tls *uint64) error { + iovec := syscall.Iovec{ + Base: (*byte)(unsafe.Pointer(tls)), + Len: uint64(unsafe.Sizeof(*tls)), + } + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETREGSET, + uintptr(t.tid), + linux.NT_ARM_TLS, + uintptr(unsafe.Pointer(&iovec)), + 0, 0) + if errno != 0 { + return errno + } + return nil +} + +// setTLS sets the thread local storage register. +func (t *thread) setTLS(tls *uint64) error { + iovec := syscall.Iovec{ + Base: (*byte)(unsafe.Pointer(tls)), + Len: uint64(unsafe.Sizeof(*tls)), + } + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_SETREGSET, + uintptr(t.tid), + linux.NT_ARM_TLS, + uintptr(unsafe.Pointer(&iovec)), + 0, 0) + if errno != 0 { + return errno + } + return nil +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 31b7cec53..a644609ef 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -506,6 +506,9 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { regs := &ac.StateData().Regs t.resetSysemuRegs(regs) + // Extract TLS register + tls := uint64(ac.TLS()) + // Check for interrupts, and ensure that future interrupts will signal t. if !c.interrupt.Enable(t) { // Pending interrupt; simulate. @@ -526,6 +529,9 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { if err := t.setFPRegs(fpState, uint64(fpLen), useXsave); err != nil { panic(fmt.Sprintf("ptrace set fpregs (%+v) failed: %v", fpState, err)) } + if err := t.setTLS(&tls); err != nil { + panic(fmt.Sprintf("ptrace set tls (%+v) failed: %v", tls, err)) + } for { // Start running until the next system call. @@ -555,6 +561,12 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { if err := t.getFPRegs(fpState, uint64(fpLen), useXsave); err != nil { panic(fmt.Sprintf("ptrace get fpregs failed: %v", err)) } + if err := t.getTLS(&tls); err != nil { + panic(fmt.Sprintf("ptrace get tls failed: %v", err)) + } + if !ac.SetTLS(uintptr(tls)) { + panic(fmt.Sprintf("tls value %v is invalid", tls)) + } // Is it a system call? if sig == (syscallEvent | syscall.SIGTRAP) { diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index f6da41c27..8122ac6e2 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -27,26 +27,27 @@ const ( _PTE_PGT_BASE = 0x7000 _PTE_PGT_SIZE = 0x1000 - _PSR_MODE_EL0t = 0x0 - _PSR_MODE_EL1t = 0x4 - _PSR_MODE_EL1h = 0x5 - _PSR_EL_MASK = 0xf - - _PSR_D_BIT = 0x200 - _PSR_A_BIT = 0x100 - _PSR_I_BIT = 0x80 - _PSR_F_BIT = 0x40 + _PSR_D_BIT = 0x00000200 + _PSR_A_BIT = 0x00000100 + _PSR_I_BIT = 0x00000080 + _PSR_F_BIT = 0x00000040 ) const ( + // PSR bits + PSR_MODE_EL0t = 0x00000000 + PSR_MODE_EL1t = 0x00000004 + PSR_MODE_EL1h = 0x00000005 + PSR_MODE_MASK = 0x0000000f + // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = _PSR_MODE_EL1h + KernelFlagsSet = PSR_MODE_EL1h // UserFlagsSet are always set in userspace. - UserFlagsSet = _PSR_MODE_EL0t + UserFlagsSet = PSR_MODE_EL0t - KernelFlagsClear = _PSR_EL_MASK - UserFlagsClear = _PSR_EL_MASK + KernelFlagsClear = PSR_MODE_MASK + UserFlagsClear = PSR_MODE_MASK PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT ) diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go index ba1f9043d..83195d5a1 100644 --- a/pkg/sentry/sighandling/sighandling.go +++ b/pkg/sentry/sighandling/sighandling.go @@ -85,6 +85,11 @@ func StartSignalForwarding(handler func(linux.Signal)) func() { for sig := 1; sig <= numSignals+1; sig++ { sigchan := make(chan os.Signal, 1) sigchans = append(sigchans, sigchan) + + // SIGURG is used by Go's runtime scheduler. + if sig == int(linux.SIGURG) { + continue + } signal.Notify(sigchan, syscall.Signal(sig)) } // Start up our listener. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 2ec11f6ac..b5b9be46f 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" @@ -35,6 +36,11 @@ import ( // shouldn't be reached - an error has occurred if we fall through to one. const errorTargetName = "ERROR" +// redirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const redirectTargetName = "REDIRECT" + // Metadata is used to verify that we are correctly serializing and // deserializing iptables into structs consumable by the iptables tool. We save // a metadata struct when the tables are written, and when they are read out we @@ -240,6 +246,8 @@ func marshalTarget(target iptables.Target) []byte { return marshalErrorTarget(tg.Name) case iptables.ReturnTarget: return marshalStandardTarget(iptables.RuleReturn) + case iptables.RedirectTarget: + return marshalRedirectTarget() case JumpTarget: return marshalJumpTarget(tg) default: @@ -276,6 +284,19 @@ func marshalErrorTarget(errorName string) []byte { return binary.Marshal(ret, usermem.ByteOrder, target) } +func marshalRedirectTarget() []byte { + // This is a redirect target named redirect + target := linux.XTRedirectTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTRedirectTarget, + }, + } + copy(target.Target.Name[:], redirectTargetName) + + ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + func marshalJumpTarget(jt JumpTarget) []byte { nflog("convert to binary: marshalling jump target") @@ -345,6 +366,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { switch replace.Name.String() { case iptables.TablenameFilter: table = iptables.EmptyFilterTable() + case iptables.TablenameNat: + table = iptables.EmptyNatTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -404,7 +427,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) return syserr.ErrInvalidArgument } - target, err := parseTarget(optVal[:targetSize]) + target, err := parseTarget(filter, optVal[:targetSize]) if err != nil { nflog("failed to parse target: %v", err) return syserr.ErrInvalidArgument @@ -495,10 +518,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT chain right now, make sure - // all other chains point to ACCEPT rules. + // Since we only support modifying the INPUT chain and redirect for + // PREROUTING chain right now, make sure all other chains point to + // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook != iptables.Input { + if hook != iptables.Input && hook != iptables.Prerouting { if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument @@ -570,7 +594,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma // parseTarget parses a target from optVal. optVal should contain only the // target. -func parseTarget(optVal []byte) (iptables.Target, error) { +func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target, error) { nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) @@ -620,6 +644,55 @@ func parseTarget(optVal []byte) (iptables.Target, error) { nflog("set entries: user-defined target %q", name) return iptables.UserChainTarget{Name: name}, nil } + + case redirectTargetName: + // Redirect target. + if len(optVal) < linux.SizeOfXTRedirectTarget { + return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) + } + + if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + var redirectTarget linux.XTRedirectTarget + buf = optVal[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to iptables.RedirectTarget. + var target iptables.RedirectTarget + nfRange := redirectTarget.NfRange + + // RangeSize should be 1. + if nfRange.RangeSize != 1 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports destination port change. + // Port range and IP range are not supported yet. + if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + target.RangeProtoSpecified = true + + target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // Convert port from big endian to little endian. + port := make([]byte, 2) + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) + target.MinPort = binary.LittleEndian.Uint16(port) + + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) + target.MaxPort = binary.LittleEndian.Uint16(port) + return target, nil } // Unknown target. @@ -630,25 +703,34 @@ func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) { if containsUnsupportedFields(iptip) { return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) } + if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { + return iptables.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + } return iptables.IPHeaderFilter{ - Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, }, nil } func containsUnsupportedFields(iptip linux.IPTIP) bool { - // Currently we check that everything except protocol is zeroed. + // The following features are supported: + // - Protocol + // - Dst and DstMask + // - The inverse destination IP check flag var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} - return iptip.Dst != emptyInetAddr || - iptip.Src != emptyInetAddr || + // Disable any supported inverse flags. + inverseMask := uint8(linux.IPT_INV_DSTIP) + return iptip.Src != emptyInetAddr || iptip.SrcMask != emptyInetAddr || - iptip.DstMask != emptyInetAddr || iptip.InputInterface != emptyInterface || iptip.OutputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || iptip.OutputInterfaceMask != emptyInterface || iptip.Flags != 0 || - iptip.InverseFlags != 0 + iptip.InverseFlags&^inverseMask != 0 } func validUnderflow(rule iptables.Rule) bool { diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index d10a9bed8..35a98212a 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -514,7 +514,7 @@ func (ac accessContext) Value(key interface{}) interface{} { } } -func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode uint) error { +func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode uint) error { const rOK = 4 const wOK = 2 const xOK = 1 @@ -529,7 +529,7 @@ func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode return syserror.EINVAL } - return fileOpOn(t, dirFD, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + return fileOpOn(t, dirFD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { // access(2) and faccessat(2) check permissions using real // UID/GID, not effective UID/GID. // @@ -564,17 +564,23 @@ func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal addr := args[0].Pointer() mode := args[1].ModeT() - return 0, nil, accessAt(t, linux.AT_FDCWD, addr, true, mode) + return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode) } // Faccessat implements linux syscall faccessat(2). +// +// Note that the faccessat() system call does not take a flags argument: +// "The raw faccessat() system call takes only the first three arguments. The +// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within +// the glibc wrapper function for faccessat(). If either of these flags is +// specified, then the wrapper function employs fstatat(2) to determine access +// permissions." - faccessat(2) func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { dirFD := args[0].Int() addr := args[1].Pointer() mode := args[2].ModeT() - flags := args[3].Int() - return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode) + return 0, nil, accessAt(t, dirFD, addr, mode) } // LINT.ThenChange(vfs2/filesystem.go) diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 9bd2df104..a11a87cd1 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -136,7 +136,10 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall mask := args[3].Uint() statxAddr := args[4].Pointer() - if mask&linux.STATX__RESERVED > 0 { + if mask&linux.STATX__RESERVED != 0 { + return 0, nil, syserror.EINVAL + } + if flags&^(linux.AT_SYMLINK_NOFOLLOW|linux.AT_EMPTY_PATH|linux.AT_STATX_SYNC_TYPE) != 0 { return 0, nil, syserror.EINVAL } if flags&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE { diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 9250659ff..136453ccc 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -173,12 +173,13 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, err } - return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{ + err = setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_SIZE, Size: uint64(length), }, }) + return 0, nil, handleSetSizeError(t, err) } // Ftruncate implements Linux syscall ftruncate(2). @@ -196,12 +197,13 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } defer file.DecRef() - return 0, nil, file.SetStat(t, vfs.SetStatOptions{ + err := file.SetStat(t, vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_SIZE, Size: uint64(length), }, }) + return 0, nil, handleSetSizeError(t, err) } // Utime implements Linux syscall utime(2). @@ -378,3 +380,12 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa FollowFinalSymlink: bool(shouldFollowFinalSymlink), }, opts) } + +func handleSetSizeError(t *kernel.Task, err error) error { + if err == syserror.ErrExceedsFileSizeLimit { + // Convert error to EFBIG and send a SIGXFSZ per setrlimit(2). + t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t)) + return syserror.EFBIG + } + return err +} diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index 12c532310..97eaedd66 100644 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -150,7 +150,11 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall mask := args[3].Uint() statxAddr := args[4].Pointer() - if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW|linux.AT_STATX_SYNC_TYPE) != 0 { + return 0, nil, syserror.EINVAL + } + + if mask&linux.STATX__RESERVED != 0 { return 0, nil, syserror.EINVAL } @@ -228,14 +232,64 @@ func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Access implements Linux syscall access(2). func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - // FIXME(jamieliu): actually implement - return 0, nil, nil + addr := args[0].Pointer() + mode := args[1].ModeT() + + return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode) } -// Faccessat implements Linux syscall access(2). +// Faccessat implements Linux syscall faccessat(2). +// +// Note that the faccessat() system call does not take a flags argument: +// "The raw faccessat() system call takes only the first three arguments. The +// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within +// the glibc wrapper function for faccessat(). If either of these flags is +// specified, then the wrapper function employs fstatat(2) to determine access +// permissions." - faccessat(2) func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - // FIXME(jamieliu): actually implement - return 0, nil, nil + dirfd := args[0].Int() + addr := args[1].Pointer() + mode := args[2].ModeT() + + return 0, nil, accessAt(t, dirfd, addr, mode) +} + +func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error { + const rOK = 4 + const wOK = 2 + const xOK = 1 + + // Sanity check the mode. + if mode&^(rOK|wOK|xOK) != 0 { + return syserror.EINVAL + } + + path, err := copyInPath(t, pathAddr) + if err != nil { + return err + } + tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink) + if err != nil { + return err + } + + // access(2) and faccessat(2) check permissions using real + // UID/GID, not effective UID/GID. + // + // "access() needs to use the real uid/gid, not the effective + // uid/gid. We do this by temporarily clearing all FS-related + // capabilities and switching the fsuid/fsgid around to the + // real ones." -fs/open.c:faccessat + creds := t.Credentials().Fork() + creds.EffectiveKUID = creds.RealKUID + creds.EffectiveKGID = creds.RealKGID + if creds.EffectiveKUID.In(creds.UserNamespace) == auth.RootUID { + creds.EffectiveCaps = creds.PermittedCaps + } else { + creds.EffectiveCaps = 0 + } + + return t.Kernel().VFS().AccessAt(t, creds, vfs.AccessTypes(mode), &tpop.pop) } // Readlinkat implements Linux syscall mknodat(2). diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 07c8383e6..a2a06fc8f 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -42,17 +42,22 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/fd", "//pkg/fspath", "//pkg/gohacks", "//pkg/log", + "//pkg/safemem", "//pkg/sentry/arch", + "//pkg/sentry/fs", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", + "//pkg/sentry/limits", "//pkg/sentry/memmap", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 2db25be49..925996517 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -41,7 +41,14 @@ func (vfs *VirtualFilesystem) NewAnonVirtualDentry(name string) VirtualDentry { } } -const anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super() +const ( + anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super() + + // Mode, UID, and GID for a generic anonfs file. + anonFileMode = 0600 // no type is correct + anonFileUID = auth.RootKUID + anonFileGID = auth.RootKGID +) // anonFilesystem is the implementation of FilesystemImpl that backs // VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry(). @@ -69,6 +76,16 @@ func (fs *anonFilesystem) Sync(ctx context.Context) error { return nil } +// AccessAt implements vfs.Filesystem.Impl.AccessAt. +// +// TODO(gvisor.dev/issue/1965): Implement access permissions. +func (fs *anonFilesystem) AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error { + if !rp.Done() { + return syserror.ENOTDIR + } + return GenericCheckPermissions(creds, ats, false /* isDir */, anonFileMode, anonFileUID, anonFileGID) +} + // GetDentryAt implements FilesystemImpl.GetDentryAt. func (fs *anonFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) { if !rp.Done() { @@ -167,9 +184,9 @@ func (fs *anonFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts St Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, Blksize: anonfsBlockSize, Nlink: 1, - UID: uint32(auth.RootKUID), - GID: uint32(auth.RootKGID), - Mode: 0600, // no type is correct + UID: uint32(anonFileUID), + GID: uint32(anonFileGID), + Mode: anonFileMode, Ino: 1, Size: 0, Blocks: 0, diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 9a1ad630c..8ee549dc2 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -286,7 +286,8 @@ type FileDescriptionImpl interface { Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) // SetStat updates metadata for the file represented by the - // FileDescription. + // FileDescription. Implementations are responsible for checking if the + // operation can be performed (see vfs.CheckSetStat() for common checks). SetStat(ctx context.Context, opts SetStatOptions) error // StatFS returns metadata for the filesystem containing the file diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index c2a52ec1b..d45e602ce 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -33,8 +33,8 @@ import ( // implementations to adapt: // - Have a local fileDescription struct (containing FileDescription) which // embeds FileDescriptionDefaultImpl and overrides the default methods -// which are common to all fd implementations for that for that filesystem -// like StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc. +// which are common to all fd implementations for that filesystem like +// StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc. // - This should be embedded in all file description implementations as the // first field by value. // - Directory FDs would also embed DirectoryFileDescriptionDefaultImpl. @@ -339,6 +339,11 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 { return 0, syserror.EOPNOTSUPP } + limit, err := CheckLimit(ctx, offset, src.NumBytes()) + if err != nil { + return 0, err + } + src = src.TakeFirst64(limit) writable, ok := fd.data.(WritableDynamicBytesSource) if !ok { diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 556976d0b..332decce6 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) // A Filesystem is a tree of nodes represented by Dentries, which forms part of @@ -144,6 +145,9 @@ type FilesystemImpl interface { // file data to be written to the underlying [filesystem]", as by syncfs(2). Sync(ctx context.Context) error + // AccessAt checks whether a user with creds can access the file at rp. + AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error + // GetDentryAt returns a Dentry representing the file at rp. A reference is // taken on the returned Dentry. // @@ -362,7 +366,9 @@ type FilesystemImpl interface { // ResolvingPath.Resolve*(), then !rp.Done(). RmdirAt(ctx context.Context, rp *ResolvingPath) error - // SetStatAt updates metadata for the file at the given path. + // SetStatAt updates metadata for the file at the given path. Implementations + // are responsible for checking if the operation can be performed + // (see vfs.CheckSetStat() for common checks). // // Errors: // diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 31a4e5480..05f6233f9 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -74,6 +74,10 @@ type Mount struct { // umounted is true. umounted is protected by VirtualFilesystem.mountMu. umounted bool + // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except + // for MS_RDONLY which is tracked in "writers". + flags MountFlags + // The lower 63 bits of writers is the number of calls to // Mount.CheckBeginWrite() that have not yet been paired with a call to // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. @@ -81,6 +85,21 @@ type Mount struct { writers int64 } +func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount { + mnt := &Mount{ + vfs: vfs, + fs: fs, + root: root, + flags: opts.Flags, + ns: mntns, + refs: 1, + } + if opts.ReadOnly { + mnt.setReadOnlyLocked(true) + } + return mnt +} + // A MountNamespace is a collection of Mounts. // // MountNamespaces are reference-counted. Unless otherwise specified, all @@ -129,13 +148,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth refs: 1, mountpoints: make(map[*Dentry]uint32), } - mntns.root = &Mount{ - vfs: vfs, - fs: fs, - root: root, - ns: mntns, - refs: 1, - } + mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{}) return mntns, nil } @@ -148,12 +161,7 @@ func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, if root != nil { root.IncRef() } - return &Mount{ - vfs: vfs, - fs: fs, - root: root, - refs: 1, - }, nil + return newMount(vfs, fs, root, nil /* mntns */, opts), nil } // MountAt creates and mounts a Filesystem configured by the given arguments. @@ -218,13 +226,7 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia // are directories, or neither are, and returns ENOTDIR if this is not the // case. mntns := vd.mount.ns - mnt := &Mount{ - vfs: vfs, - fs: fs, - root: root, - ns: mntns, - refs: 1, - } + mnt := newMount(vfs, fs, root, mntns, opts) vfs.mounts.seq.BeginWrite() vfs.connectLocked(mnt, vd, mntns) vfs.mounts.seq.EndWrite() diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 6af7fdac1..3e90dc4ed 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -46,8 +46,21 @@ type MknodOptions struct { DevMinor uint32 } +// MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC. +// MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers. +type MountFlags struct { + // NoExec is equivalent to MS_NOEXEC. + NoExec bool +} + // MountOptions contains options to VirtualFilesystem.MountAt(). type MountOptions struct { + // Flags contains flags as specified for mount(2), e.g. MS_NOEXEC. + Flags MountFlags + + // ReadOnly is equivalent to MS_RDONLY. + ReadOnly bool + // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). GetFilesystemOptions GetFilesystemOptions @@ -75,7 +88,8 @@ type OpenOptions struct { // FileExec is set when the file is being opened to be executed. // VirtualFilesystem.OpenAt() checks that the caller has execute permissions - // on the file, and that the file is a regular file. + // on the file, that the file is a regular file, and that the mount doesn't + // have MS_NOEXEC set. FileExec bool } diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 8e250998a..2c8f23f55 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -15,8 +15,12 @@ package vfs import ( + "math" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/syserror" ) @@ -147,7 +151,16 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { // CheckSetStat checks that creds has permission to change the metadata of a // file with the given permissions, UID, and GID as specified by stat, subject // to the rules of Linux's fs/attr.c:setattr_prepare(). -func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error { +func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error { + if stat.Mask&linux.STATX_SIZE != 0 { + limit, err := CheckLimit(ctx, 0, int64(stat.Size)) + if err != nil { + return err + } + if limit < int64(stat.Size) { + return syserror.ErrExceedsFileSizeLimit + } + } if stat.Mask&linux.STATX_MODE != 0 { if !CanActAsOwner(creds, kuid) { return syserror.EPERM @@ -205,3 +218,21 @@ func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool { func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool { return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok() } + +// CheckLimit enforces file size rlimits. It returns error if the write +// operation must not proceed. Otherwise it returns the max length allowed to +// without violating the limit. +func CheckLimit(ctx context.Context, offset, size int64) (int64, error) { + fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur + if fileSizeLimit > math.MaxInt64 { + return size, nil + } + if offset >= int64(fileSizeLimit) { + return 0, syserror.ErrExceedsFileSizeLimit + } + remaining := int64(fileSizeLimit) - offset + if remaining < size { + return remaining, nil + } + return size, nil +} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index bde81e1ef..2e2880171 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -174,6 +174,23 @@ type PathOperation struct { FollowFinalSymlink bool } +// AccessAt checks whether a user with creds has access to the file at +// the given path. +func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credentials, ats AccessTypes, pop *PathOperation) error { + rp := vfs.getResolvingPath(creds, pop) + for { + err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + // GetDentryAt returns a VirtualDentry representing the given path, at which a // file must exist. A reference is taken on the returned VirtualDentry. func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { @@ -388,6 +405,11 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential // TODO(gvisor.dev/issue/1193): Move inside fsimpl to avoid another call // to FileDescription.Stat(). if opts.FileExec { + if fd.Mount().flags.NoExec { + fd.DecRef() + return nil, syserror.EACCES + } + // Only a regular file can be executed. stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE}) if err != nil { diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index a984f1712..e57d45f2a 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -22,6 +22,7 @@ go_test( size = "small", srcs = ["gonet_test.go"], library = ":gonet", + tags = ["flaky"], deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c6c160dfc..8dc0f7c0e 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -785,6 +785,52 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } +// ndpOptions checks that optsBuf only contains opts. +func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { + t.Helper() + + it, err := optsBuf.Iter(true) + if err != nil { + t.Errorf("optsBuf.Iter(true): %s", err) + return + } + + i := 0 + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + t.Fatalf("unexpected error when iterating over NDP options: %s", err) + } + if done { + break + } + + if i >= len(opts) { + t.Errorf("got unexpected option: %s", opt) + continue + } + + switch wantOpt := opts[i].(type) { + case header.NDPSourceLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + default: + t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) + } + + i++ + } + + if missing := opts[i:]; len(missing) > 0 { + t.Errorf("missing options: %s", missing) + } +} + // NDPNSOptions creates a checker that checks that the packet contains the // provided NDP options within an NDP Neighbor Solicitation message. // @@ -796,47 +842,31 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { icmp := h.(header.ICMPv6) ns := header.NDPNeighborSolicit(icmp.NDPPayload()) - it, err := ns.Options().Iter(true) - if err != nil { - t.Errorf("opts.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, _ := it.Next() - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - default: - panic("not implemented") - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } + ndpOptions(t, ns.Options(), opts) } } // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). -func NDPRS() NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize) +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPRS as far as the size of the message is concerned. The values within the +// message are up to checkers to validate. +func NDPRS(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) +} + +// NDPRSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Router Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPRS message as far as the size is concerned. +func NDPRSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + rs := header.NDPRouterSolicit(icmp.NDPPayload()) + ndpOptions(t, rs.Options(), opts) + } } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e5360e7c1..76839eb92 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -38,7 +38,8 @@ const ( // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the // fields of a packet that needs to be encoded. type IPv4Fields struct { - // IHL is the "internet header length" field of an IPv4 packet. + // IHL is the "internet header length" field of an IPv4 packet. The value + // is in bytes. IHL uint8 // TOS is the "type of service" field of an IPv4 packet. @@ -138,7 +139,7 @@ func IPVersion(b []byte) int { } // HeaderLength returns the value of the "header length" field of the ipv4 -// header. +// header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { return (b[versIHL] & 0xf) * 4 } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 82cfe785c..13480687d 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -81,7 +81,8 @@ type TCPFields struct { // AckNum is the "acknowledgement number" field of a TCP packet. AckNum uint32 - // DataOffset is the "data offset" field of a TCP packet. + // DataOffset is the "data offset" field of a TCP packet. It is the length of + // the TCP header in bytes. DataOffset uint8 // Flags is the "flags" field of a TCP packet. @@ -213,7 +214,8 @@ func (b TCP) AckNumber() uint32 { return binary.BigEndian.Uint32(b[TCPAckNumOffset:]) } -// DataOffset returns the "data offset" field of the tcp header. +// DataOffset returns the "data offset" field of the tcp header. The return +// value is the length of the TCP header in bytes. func (b TCP) DataOffset() uint8 { return (b[TCPDataOffset] >> 4) * 4 } @@ -238,6 +240,11 @@ func (b TCP) Checksum() uint16 { return binary.BigEndian.Uint16(b[TCPChecksumOffset:]) } +// UrgentPointer returns the "urgent pointer" field of the tcp header. +func (b TCP) UrgentPointer() uint16 { + return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:]) +} + // SetSourcePort sets the "source port" field of the tcp header. func (b TCP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port) @@ -253,6 +260,37 @@ func (b TCP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum) } +// SetDataOffset sets the data offset field of the tcp header. headerLen should +// be the length of the TCP header in bytes. +func (b TCP) SetDataOffset(headerLen uint8) { + b[TCPDataOffset] = (headerLen / 4) << 4 +} + +// SetSequenceNumber sets the sequence number field of the tcp header. +func (b TCP) SetSequenceNumber(seqNum uint32) { + binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum) +} + +// SetAckNumber sets the ack number field of the tcp header. +func (b TCP) SetAckNumber(ackNum uint32) { + binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum) +} + +// SetFlags sets the flags field of the tcp header. +func (b TCP) SetFlags(flags uint8) { + b[TCPFlagsOffset] = flags +} + +// SetWindowSize sets the window size field of the tcp header. +func (b TCP) SetWindowSize(rcvwnd uint16) { + binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) +} + +// SetUrgentPoiner sets the window size field of the tcp header. +func (b TCP) SetUrgentPoiner(urgentPointer uint16) { + binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer) +} + // CalculateChecksum calculates the checksum of the tcp segment. // partialChecksum is the checksum of the network-layer pseudo-header // and the checksum of the segment data. diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index dbaccbb36..d30571c74 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -135,6 +135,27 @@ func EmptyFilterTable() Table { } } +// EmptyNatTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyNatTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + Underflows: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + UserChains: map[string]int{}, + } +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -240,9 +261,14 @@ func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, r func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // First check whether the packet matches the IP header filter. - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() { + // If pkt.NetworkHeader hasn't been set yet, it will be contained in + // pkt.Data.First(). + if pkt.NetworkHeader == nil { + pkt.NetworkHeader = pkt.Data.First() + } + + // Check whether the packet matches the IP header filter. + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -263,3 +289,26 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru // All the matchers matched, so run the target. return rule.Target.Action(pkt) } + +func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the destination IP. + dest := hdr.DestinationAddress() + matches := true + for i := range filter.Dst { + if dest[i]&filter.DstMask[i] != filter.Dst[i] { + matches = false + break + } + } + if matches == filter.DstInvert { + return false + } + + return true +} diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 81a2e39a2..e457f2349 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -17,6 +17,7 @@ package iptables import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // AcceptTarget accepts packets. @@ -63,3 +64,81 @@ type ReturnTarget struct{} func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) { return RuleReturn, 0 } + +// RedirectTarget redirects the packet by modifying the destination port/IP. +// Min and Max values for IP and Ports in the struct indicate the range of +// values which can be used to redirect. +type RedirectTarget struct { + // TODO(gvisor.dev/issue/170): Other flags need to be added after + // we support them. + // RangeProtoSpecified flag indicates single port is specified to + // redirect. + RangeProtoSpecified bool + + // Min address used to redirect. + MinIP tcpip.Address + + // Max address used to redirect. + MaxIP tcpip.Address + + // Min port used to redirect. + MinPort uint16 + + // Max port used to redirect. + MaxPort uint16 +} + +// Action implements Target.Action. +// TODO(gvisor.dev/issue/170): Parse headers without copying. The current +// implementation only works for PREROUTING and calls pkt.Clone(), neither +// of which should be the case. +func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer) (RuleVerdict, int) { + newPkt := pkt.Clone() + + // Set network header. + headerView := newPkt.Data.First() + netHeader := header.IPv4(headerView) + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + + hlen := int(netHeader.HeaderLength()) + tlen := int(netHeader.TotalLength()) + newPkt.Data.TrimFront(hlen) + newPkt.Data.CapLength(tlen - hlen) + + // TODO(gvisor.dev/issue/170): Change destination address to + // loopback or interface address on which the packet was + // received. + + // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if + // we need to change dest address (for OUTPUT chain) or ports. + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + var udpHeader header.UDP + if newPkt.TransportHeader != nil { + udpHeader = header.UDP(newPkt.TransportHeader) + } else { + if len(pkt.Data.First()) < header.UDPMinimumSize { + return RuleDrop, 0 + } + udpHeader = header.UDP(newPkt.Data.First()) + } + udpHeader.SetDestinationPort(rt.MinPort) + case header.TCPProtocolNumber: + var tcpHeader header.TCP + if newPkt.TransportHeader != nil { + tcpHeader = header.TCP(newPkt.TransportHeader) + } else { + if len(pkt.Data.First()) < header.TCPMinimumSize { + return RuleDrop, 0 + } + tcpHeader = header.TCP(newPkt.TransportHeader) + } + // TODO(gvisor.dev/issue/170): Need to recompute checksum + // and implement nat connection tracking to support TCP. + tcpHeader.SetDestinationPort(rt.MinPort) + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 7d032fd23..e7fcf6bff 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -144,6 +144,18 @@ type Rule struct { type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber + + // Dst matches the destination IP address. + Dst tcpip.Address + + // DstMask masks bits of the destination IP address when comparing with + // Dst. + DstMask tcpip.Address + + // DstInvert inverts the meaning of the destination IP check, i.e. when + // true the filter will match packets that fail the destination + // comparison. + DstInvert bool } // A Matcher is the interface for matching packets. diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index b7f60178e..3b36b9673 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -407,7 +407,6 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} - vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) if gso != nil { vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if gso.NeedsCsum { @@ -428,6 +427,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } } + vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) } @@ -467,7 +467,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac views := pkts[0].Data.Views() /* - * Each bondary in views can add one more iovec. + * Each boundary in views can add one more iovec. * * payload | | | | * ----------------------------- diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go index 97a477b61..d81858353 100644 --- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go +++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go @@ -24,9 +24,10 @@ import ( const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{})) func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) { - sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - sh.Data = uintptr(unsafe.Pointer(hdr)) - sh.Len = virtioNetHdrSize - sh.Cap = virtioNetHdrSize + *(*reflect.SliceHeader)(unsafe.Pointer(&slice)) = reflect.SliceHeader{ + Data: uintptr((unsafe.Pointer(hdr))), + Len: virtioNetHdrSize, + Cap: virtioNetHdrSize, + } return } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6ff47a742..f6e301304 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -98,7 +98,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { prefix = "tap" } - endpoint, err := attachOrCreateNIC(s, name, prefix) + linkCaps := stack.CapabilityNone + if isTap { + linkCaps |= stack.CapabilityResolutionRequired + } + + endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { return syserror.EINVAL } @@ -109,7 +114,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { return nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { for { // 1. Try to attach to an existing NIC. if name != "" { @@ -135,6 +140,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error nicID: id, name: name, } + endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) } diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 92f2aa13a..f42abc4bb 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -115,10 +115,12 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. if f.size > f.highLimit { - tail := f.rList.Back() - for f.size > f.lowLimit && tail != nil { + for f.size > f.lowLimit { + tail := f.rList.Back() + if tail == nil { + break + } f.release(tail) - tail = tail.Prev() } } f.mu.Unlock() diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 705cf01ee..6c029b2fb 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -18,6 +18,8 @@ go_template_instance( go_library( name = "stack", srcs = [ + "dhcpv6configurationfromndpra_string.go", + "forwarder.go", "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", @@ -79,6 +81,7 @@ go_test( name = "stack_test", size = "small", srcs = [ + "forwarder_test.go", "linkaddrcache_test.go", "nic_test.go", ], diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go new file mode 100644 index 000000000..8b4213eec --- /dev/null +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DHCPv6NoConfiguration-0] + _ = x[DHCPv6ManagedAddress-1] + _ = x[DHCPv6OtherConfigurations-2] +} + +const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" + +var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} + +func (i DHCPv6ConfigurationFromNDPRA) String() string { + if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] +} diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go new file mode 100644 index 000000000..631953935 --- /dev/null +++ b/pkg/tcpip/stack/forwarder.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // maxPendingResolutions is the maximum number of pending link-address + // resolutions. + maxPendingResolutions = 64 + maxPendingPacketsPerResolution = 256 +) + +type pendingPacket struct { + nic *NIC + route *Route + proto tcpip.NetworkProtocolNumber + pkt tcpip.PacketBuffer +} + +type forwardQueue struct { + sync.Mutex + + // The packets to send once the resolver completes. + packets map[<-chan struct{}][]*pendingPacket + + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + cancelChans []chan struct{} +} + +func newForwardQueue() *forwardQueue { + return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} +} + +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + shouldWait := false + + f.Lock() + packets, ok := f.packets[ch] + if !ok { + shouldWait = true + } + for len(packets) == maxPendingPacketsPerResolution { + p := packets[0] + packets = packets[1:] + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Release() + } + if l := len(packets); l >= maxPendingPacketsPerResolution { + panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) + } + f.packets[ch] = append(packets, &pendingPacket{ + nic: n, + route: r, + proto: protocol, + pkt: pkt, + }) + f.Unlock() + + if !shouldWait { + return + } + + // Wait for the link-address resolution to complete. + // Start a goroutine with a forwarding-cancel channel so that we can + // limit the maximum number of goroutines running concurrently. + cancel := f.newCancelChannel() + go func() { + cancelled := false + select { + case <-ch: + case <-cancel: + cancelled = true + } + + f.Lock() + packets := f.packets[ch] + delete(f.packets, ch) + f.Unlock() + + for _, p := range packets { + if cancelled { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else if _, err := p.route.Resolve(nil); err != nil { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else { + p.nic.forwardPacket(p.route, p.proto, p.pkt) + } + p.route.Release() + } + }() +} + +// newCancelChannel creates a channel that can cancel a pending forwarding +// activity. The oldest channel is closed if the number of open channels would +// exceed maxPendingResolutions. +func (f *forwardQueue) newCancelChannel() chan struct{} { + f.Lock() + defer f.Unlock() + + if len(f.cancelChans) == maxPendingResolutions { + ch := f.cancelChans[0] + f.cancelChans = f.cancelChans[1:] + close(ch) + } + if l := len(f.cancelChans); l >= maxPendingResolutions { + panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) + } + + ch := make(chan struct{}) + f.cancelChans = append(f.cancelChans, ch) + return ch +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go new file mode 100644 index 000000000..321b7524d --- /dev/null +++ b/pkg/tcpip/stack/forwarder_test.go @@ -0,0 +1,635 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +const ( + fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fwdTestNetHeaderLen = 12 + fwdTestNetDefaultPrefixLen = 8 + + // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match + // the MTU of loopback interfaces on linux systems. + fwdTestNetDefaultMTU = 65536 +) + +// fwdTestNetworkEndpoint is a network-layer protocol endpoint. +// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only +// use the first three: destination address, source address, and transport +// protocol. They're all one byte fields to simplify parsing. +type fwdTestNetworkEndpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + proto *fwdTestNetworkProtocol + dispatcher TransportDispatcher + ep LinkEndpoint +} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) +} + +func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { + return f.nicID +} + +func (f *fwdTestNetworkEndpoint) PrefixLen() int { + return f.prefixLen +} + +func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { + return 123 +} + +func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { + return &f.id +} + +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt tcpip.PacketBuffer) { + // Consume the network header. + b := pkt.Data.First() + pkt.Data.TrimFront(fwdTestNetHeaderLen) + + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) +} + +func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { + return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { + return 0 +} + +func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { + return f.ep.Capabilities() +} + +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { + // Add the protocol's header to the packet and send it to the link + // endpoint. + b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b[0] = r.RemoteAddress[0] + b[1] = f.id.LocalAddress[0] + b[2] = byte(params.Protocol) + + return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { + panic("not implemented") +} + +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func (*fwdTestNetworkEndpoint) Close() {} + +// fwdTestNetworkProtocol is a network-layer protocol that implements Address +// resolution. +type fwdTestNetworkProtocol struct { + addrCache *linkAddrCache + addrResolveDelay time.Duration + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) +} + +func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { + return fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { + return fwdTestNetDefaultPrefixLen +} + +func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) +} + +func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &fwdTestNetworkEndpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + proto: f, + dispatcher: dispatcher, + ep: ep, + }, nil +} + +func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Close() {} + +func (f *fwdTestNetworkProtocol) Wait() {} + +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { + if f.addrCache != nil && f.onLinkAddressResolved != nil { + time.AfterFunc(f.addrResolveDelay, func() { + f.onLinkAddressResolved(f.addrCache, addr) + }) + } + return nil +} + +func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if f.onResolveStaticAddress != nil { + return f.onResolveStaticAddress(addr) + } + return "", false +} + +func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +// fwdTestPacketInfo holds all the information about an outbound packet. +type fwdTestPacketInfo struct { + RemoteLinkAddress tcpip.LinkAddress + LocalLinkAddress tcpip.LinkAddress + Pkt tcpip.PacketBuffer +} + +type fwdTestLinkEndpoint struct { + dispatcher NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + + // C is where outbound packets are queued. + C chan fwdTestPacketInfo +} + +// InjectInbound injects an inbound packet. +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) +} + +// InjectLinkAddr injects an inbound packet with a remote link address. +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *fwdTestLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *fwdTestLinkEndpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { + caps := LinkEndpointCapabilities(0) + return caps | CapabilityResolutionRequired +} + +// GSOMaxSize returns the maximum GSO packet size. +func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { + return 1 << 15 +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + p := fwdTestPacketInfo{ + RemoteLinkAddress: r.RemoteLinkAddress, + LocalLinkAddress: r.LocalLinkAddress, + Pkt: pkt, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := 0 + for _, pkt := range pkts { + e.WritePacket(r, gso, protocol, pkt) + n++ + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + p := fwdTestPacketInfo{ + Pkt: tcpip.PacketBuffer{Data: vv}, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*fwdTestLinkEndpoint) Wait() {} + +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { + // Create a stack with the network protocol and two NICs. + s := New(Options{ + NetworkProtocols: []NetworkProtocol{proto}, + }) + + proto.addrCache = s.linkAddrCache + + // Enable forwarding. + s.SetForwarding(true) + + // NIC 1 has the link address "a", and added the network address 1. + ep1 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "a", + } + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC #1 failed:", err) + } + if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { + t.Fatal("AddAddress #1 failed:", err) + } + + // NIC 2 has the link address "b", and added the network address 2. + ep2 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "b", + } + if err := s.CreateNIC(2, ep2); err != nil { + t.Fatal("CreateNIC #2 failed:", err) + } + if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { + t.Fatal("AddAddress #2 failed:", err) + } + + // Route all packets to NIC 2. + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) + } + + return ep1, ep2 +} + +func TestForwardingWithStaticResolver(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolver(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithNoResolver(t *testing.T) { + // Create a network protocol without a resolver. + proto := &fwdTestNetworkProtocol{} + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + select { + case <-ep2.C: + t.Fatal("Packet should not be forwarded") + case <-time.After(time.Second): + } +} + +func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[0] = 4 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Header.View() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Header.View() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[0] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Header.View() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + // The first 5 packets should not be forwarded so the the + // sequemnce number should start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[0] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + b := p.Pkt.Header.View() + if b[0] < 8 { + t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index f651871ce..d689a006d 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -361,16 +361,16 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState + // The timer used to send the next router solicitation message. + rtrSolicitTimer *time.Timer + // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState - // The timer used to send the next router solicitation message. - // If routers are being solicited, rtrSolicitTimer MUST NOT be nil. - rtrSolicitTimer *time.Timer - - // The addresses generated by SLAAC. - autoGenAddresses map[tcpip.Address]autoGenAddressState + // The SLAAC prefixes discovered through Router Advertisements' Prefix + // Information option. + slaacPrefixes map[tcpip.Subnet]slaacPrefixState // The last learned DHCPv6 configuration from an NDP RA. dhcpv6Configuration DHCPv6ConfigurationFromNDPRA @@ -402,18 +402,16 @@ type onLinkPrefixState struct { invalidationTimer tcpip.CancellableTimer } -// autoGenAddressState holds data associated with an address generated via -// SLAAC. -type autoGenAddressState struct { - // A reference to the referencedNetworkEndpoint that this autoGenAddressState - // is holding state for. - ref *referencedNetworkEndpoint - +// slaacPrefixState holds state associated with a SLAAC prefix. +type slaacPrefixState struct { deprecationTimer tcpip.CancellableTimer invalidationTimer tcpip.CancellableTimer // Nonzero only when the address is not valid forever. validUntil time.Time + + // The prefix's permanent address endpoint. + ref *referencedNetworkEndpoint } // startDuplicateAddressDetection performs Duplicate Address Detection. @@ -899,23 +897,15 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() - // Check if we already have an auto-generated address for prefix. - for addr, addrState := range ndp.autoGenAddresses { - refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()} - if refAddrWithPrefix.Subnet() != prefix { - continue - } - - // At this point, we know we are refreshing a SLAAC generated IPv6 address - // with the prefix prefix. Do the work as outlined by RFC 4862 section - // 5.5.3.e. - ndp.refreshAutoGenAddressLifetimes(addr, pl, vl) + // Check if we already maintain SLAAC state for prefix. + if _, ok := ndp.slaacPrefixes[prefix]; ok { + // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. + ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl) return } - // We do not already have an address with the prefix prefix. Do the - // work as outlined by RFC 4862 section 5.5.3.d if n is configured - // to auto-generate global addresses by SLAAC. + // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC. if !ndp.configs.AutoGenGlobalAddresses { return } @@ -927,6 +917,8 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform // for prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 @@ -942,9 +934,59 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { return } + // If the preferred lifetime is zero, then the prefix should be considered + // deprecated. + deprecated := pl == 0 + ref := ndp.addSLAACAddr(prefix, deprecated) + if ref == nil { + // We were unable to generate a permanent address for prefix so do nothing + // further as there is no reason to maintain state for a SLAAC prefix we + // cannot generate a permanent address for. + return + } + + state := slaacPrefixState{ + deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + log.Fatalf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix) + } + + ndp.deprecateSLAACAddress(prefixState.ref) + }), + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateSLAACPrefix(prefix, true) + }), + ref: ref, + } + + // Setup the initial timers to deprecate and invalidate prefix. + + if !deprecated && pl < header.NDPInfiniteLifetime { + state.deprecationTimer.Reset(pl) + } + + if vl < header.NDPInfiniteLifetime { + state.invalidationTimer.Reset(vl) + state.validUntil = time.Now().Add(vl) + } + + ndp.slaacPrefixes[prefix] = state +} + +// addSLAACAddr adds a SLAAC address for prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referencedNetworkEndpoint { addrBytes := []byte(prefix.ID()) if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey) + addrBytes = header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + 0, /* dadCounter */ + oIID.SecretKey, + ) } else { // Only attempt to generate an interface-specific IID if we have a valid // link address. @@ -953,137 +995,103 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // LinkEndpoint.LinkAddress) before reaching this point. linkAddr := ndp.nic.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { - return + return nil } // Generate an address within prefix from the modified EUI-64 of ndp's NIC's // Ethernet MAC address. header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) } - addr := tcpip.Address(addrBytes) - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: validPrefixLenForAutoGen, + + generatedAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, + }, } // If the nic already has this address, do nothing further. - if ndp.nic.hasPermanentAddrLocked(addr) { - return + if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { + return nil } // Inform the integrator that we have a new SLAAC address. ndpDisp := ndp.nic.stack.ndpDisp if ndpDisp == nil { - return + return nil } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { // Informed by the integrator not to add the address. - return + return nil } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addrWithPrefix, - } - // If the preferred lifetime is zero, then the address should be considered - // deprecated. - deprecated := pl == 0 - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) if err != nil { - log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) - } - - state := autoGenAddressState{ - ref: ref, - deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - addrState, ok := ndp.autoGenAddresses[addr] - if !ok { - log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) - } - addrState.ref.deprecated = true - ndp.notifyAutoGenAddressDeprecated(addr) - }), - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - ndp.invalidateAutoGenAddress(addr) - }), - } - - // Setup the initial timers to deprecate and invalidate this newly generated - // address. - - if !deprecated && pl < header.NDPInfiniteLifetime { - state.deprecationTimer.Reset(pl) + log.Fatalf("ndp: error when adding address %+v: %s", generatedAddr, err) } - if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) - state.validUntil = time.Now().Add(vl) - } - - ndp.autoGenAddresses[addr] = state + return ref } -// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated -// address addr. +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. -func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) { - addrState, ok := ndp.autoGenAddresses[addr] +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) { + prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { - log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr) + log.Fatalf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix) } - defer func() { ndp.autoGenAddresses[addr] = addrState }() + defer func() { ndp.slaacPrefixes[prefix] = prefixState }() - // If the preferred lifetime is zero, then the address should be considered - // deprecated. + // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 - wasDeprecated := addrState.ref.deprecated - addrState.ref.deprecated = deprecated - - // Only send the deprecation event if the deprecated status for addr just - // changed from non-deprecated to deprecated. - if !wasDeprecated && deprecated { - ndp.notifyAutoGenAddressDeprecated(addr) + if deprecated { + ndp.deprecateSLAACAddress(prefixState.ref) + } else { + prefixState.ref.deprecated = false } - // If addr was preferred for some finite lifetime before, stop the deprecation - // timer so it can be reset. - addrState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, stop the + // deprecation timer so it can be reset. + prefixState.deprecationTimer.StopLocked() - // Reset the deprecation timer if addr has a finite preferred lifetime. + // Reset the deprecation timer if prefix has a finite preferred lifetime. if !deprecated && pl < header.NDPInfiniteLifetime { - addrState.deprecationTimer.Reset(pl) + prefixState.deprecationTimer.Reset(pl) } - // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address - // + // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: // // 1) If the received Valid Lifetime is greater than 2 hours or greater than - // RemainingLifetime, set the valid lifetime of the address to the + // RemainingLifetime, set the valid lifetime of the prefix to the // advertised Valid Lifetime. // // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the // advertised Valid Lifetime. // - // 3) Otherwise, reset the valid lifetime of the address to 2 hours. + // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. // Handle the infinite valid lifetime separately as we do not keep a timer in // this case. if vl >= header.NDPInfiniteLifetime { - addrState.invalidationTimer.StopLocked() - addrState.validUntil = time.Time{} + prefixState.invalidationTimer.StopLocked() + prefixState.validUntil = time.Time{} return } var effectiveVl time.Duration var rl time.Duration - // If the address was originally set to be valid forever, assume the remaining + // If the prefix was originally set to be valid forever, assume the remaining // time to be the maximum possible value. - if addrState.validUntil == (time.Time{}) { + if prefixState.validUntil == (time.Time{}) { rl = header.NDPInfiniteLifetime } else { - rl = time.Until(addrState.validUntil) + rl = time.Until(prefixState.validUntil) } if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { @@ -1094,58 +1102,66 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - addrState.invalidationTimer.StopLocked() - addrState.invalidationTimer.Reset(effectiveVl) - addrState.validUntil = time.Now().Add(effectiveVl) -} - -// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr -// has been deprecated. -func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: validPrefixLenForAutoGen, - }) - } + prefixState.invalidationTimer.StopLocked() + prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.validUntil = time.Now().Add(effectiveVl) } -// invalidateAutoGenAddress invalidates an auto-generated address. +// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP +// dispatcher that ref has been deprecated. +// +// deprecateSLAACAddress does nothing if ref is already deprecated. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { - if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { +func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { + if ref.deprecated { return } - ndp.nic.removePermanentAddressLocked(addr) + ref.deprecated = true + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }) + } } -// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated -// address's resources from ndp. If the stack has an NDP dispatcher, it will -// be notified that addr has been invalidated. -// -// Returns true if ndp had resources for addr to cleanup. +// invalidateSLAACPrefix invalidates a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { - state, ok := ndp.autoGenAddresses[addr] +func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) { + state, ok := ndp.slaacPrefixes[prefix] if !ok { - return false + return } state.deprecationTimer.StopLocked() state.invalidationTimer.StopLocked() - delete(ndp.autoGenAddresses, addr) + delete(ndp.slaacPrefixes, prefix) + + addr := state.ref.ep.ID().LocalAddress + + if removeAddr { + if err := ndp.nic.removePermanentAddressLocked(addr); err != nil { + log.Fatalf("ndp: removePermanentAddressLocked(%s): %s", addr, err) + } + } if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ Address: addr, - PrefixLen: validPrefixLenForAutoGen, + PrefixLen: state.ref.ep.PrefixLen(), }) } +} - return true +// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC +// address's resources from ndp. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix) { + ndp.invalidateSLAACPrefix(addr.Subnet(), false) } // cleanupState cleans up ndp's state. @@ -1163,21 +1179,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - linkLocalAddrs := 0 - for addr := range ndp.autoGenAddresses { + linkLocalPrefixes := 0 + for prefix := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. - if hostOnly && linkLocalSubnet.Contains(addr) { - linkLocalAddrs++ + if hostOnly && prefix == linkLocalSubnet { + linkLocalPrefixes++ continue } - ndp.invalidateAutoGenAddress(addr) + ndp.invalidateSLAACPrefix(prefix, true) } - if got := len(ndp.autoGenAddresses); got != linkLocalAddrs { - log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs) + if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { + log.Fatalf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes) } for prefix := range ndp.onLinkPrefixes { @@ -1220,9 +1236,15 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { - // Send an RS message with the unspecified source address. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + // As per RFC 4861 section 4.1, the source of the RS is an address assigned + // to the sending interface, or the unspecified address if no address is + // assigned to the sending interface. + ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + if ref == nil { + ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + } + localAddr := ref.ep.ID().LocalAddress + r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since @@ -1234,10 +1256,25 @@ func (ndp *ndpState) startSolicitingRouters() { log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source + // link-layer address option if the source address of the NDP RS is + // specified. This option MUST NOT be included if the source address is + // unspecified. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + var optsSerializer header.NDPOptionsSerializer + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + optsSerializer = header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + } + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(pkt.NDPPayload()) + rs.Options().Serialize(optsSerializer) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6e9306d09..4368c236c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -639,8 +639,9 @@ func TestDADStop(t *testing.T) { const nicID = 1 tests := []struct { - name string - stopFn func(t *testing.T, s *stack.Stack) + name string + stopFn func(t *testing.T, s *stack.Stack) + skipFinalAddrCheck bool }{ // Tests to make sure that DAD stops when an address is removed. { @@ -661,6 +662,19 @@ func TestDADStop(t *testing.T) { } }, }, + + // Tests to make sure that DAD stops when the NIC is removed. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("RemoveNIC(%d): %s", nicID, err) + } + }, + // The NIC is removed so we can't check its addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, } for _, test := range tests { @@ -710,12 +724,15 @@ func TestDADStop(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + + if !test.skipFinalAddrCheck { + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } } // Should not have sent more than 1 NS message. @@ -2983,11 +3000,12 @@ func TestCleanupNDPState(t *testing.T) { cleanupFn func(t *testing.T, s *stack.Stack) keepAutoGenLinkLocal bool maxAutoGenAddrEvents int + skipFinalAddrCheck bool }{ // A NIC should still keep its auto-generated link-local address when // becoming a router. { - name: "Forwarding Enable", + name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() s.SetForwarding(true) @@ -2998,7 +3016,7 @@ func TestCleanupNDPState(t *testing.T) { // A NIC should cleanup all NDP state when it is disabled. { - name: "NIC Disable", + name: "Disable NIC", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() @@ -3012,6 +3030,26 @@ func TestCleanupNDPState(t *testing.T) { keepAutoGenLinkLocal: false, maxAutoGenAddrEvents: 6, }, + + // A NIC should cleanup all NDP state when it is removed. + { + name: "Remove NIC", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + + if err := s.RemoveNIC(nicID1); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) + } + if err := s.RemoveNIC(nicID2); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + // The NICs are removed so we can't check their addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, } for _, test := range tests { @@ -3230,35 +3268,37 @@ func TestCleanupNDPState(t *testing.T) { t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) } - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + if !test.skipFinalAddrCheck { + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } } - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } // Should not get any more events (invalidation timers should have been @@ -3384,6 +3424,10 @@ func TestRouterSolicitation(t *testing.T) { tests := []struct { name string linkHeaderLen uint16 + linkAddr tcpip.LinkAddress + nicAddr tcpip.Address + expectedSrcAddr tcpip.Address + expectedNDPOpts []header.NDPOption maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3392,6 +3436,7 @@ func TestRouterSolicitation(t *testing.T) { }{ { name: "Single RS with delay", + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3401,6 +3446,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS with delay", linkHeaderLen: 1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3408,8 +3455,14 @@ func TestRouterSolicitation(t *testing.T) { effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, }, { - name: "Single RS without delay", - linkHeaderLen: 2, + name: "Single RS without delay", + linkHeaderLen: 2, + linkAddr: linkAddr1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, + expectedNDPOpts: []header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3419,6 +3472,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS without delay and invalid zero interval", linkHeaderLen: 3, + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3427,6 +3482,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Three RS without delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 3, rtrSolicitInt: 500 * time.Millisecond, effectiveRtrSolicitInt: 500 * time.Millisecond, @@ -3435,6 +3492,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with invalid negative delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3457,7 +3516,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -3481,10 +3540,10 @@ func TestRouterSolicitation(t *testing.T) { checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), + checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), - checker.NDPRS(), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { @@ -3510,13 +3569,19 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - // Make sure each RS got sent at the right - // times. + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } + + // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining > 0 { waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) remaining-- } + for ; remaining > 0; remaining-- { waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) waitForPkt(defaultAsyncEventTimeout) @@ -3550,17 +3615,19 @@ func TestStopStartSolicitingRouters(t *testing.T) { tests := []struct { name string startFn func(t *testing.T, s *stack.Stack) - stopFn func(t *testing.T, s *stack.Stack) + // first is used to tell stopFn that it is being called for the first time + // after router solicitations were last enabled. + stopFn func(t *testing.T, s *stack.Stack, first bool) }{ // Tests that when forwarding is enabled or disabled, router solicitations // are stopped or started, respectively. { - name: "Forwarding enabled and disabled", + name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() s.SetForwarding(false) }, - stopFn: func(t *testing.T, s *stack.Stack) { + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() s.SetForwarding(true) }, @@ -3569,7 +3636,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Tests that when a NIC is enabled or disabled, router solicitations // are started or stopped, respectively. { - name: "NIC disabled and enabled", + name: "Enable and disable NIC", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() @@ -3577,7 +3644,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } }, - stopFn: func(t *testing.T, s *stack.Stack) { + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() if err := s.DisableNIC(nicID); err != nil { @@ -3585,6 +3652,25 @@ func TestStopStartSolicitingRouters(t *testing.T) { } }, }, + + // Tests that when a NIC is removed, router solicitations are stopped. We + // cannot start router solications on a removed NIC. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack, first bool) { + t.Helper() + + // Only try to remove the NIC the first time stopFn is called since it's + // impossible to remove an already removed NIC. + if !first { + return + } + + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) + } + }, + }, } for _, test := range tests { @@ -3623,7 +3709,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { } // Stop soliciting routers. - test.stopFn(t, s) + test.stopFn(t, s, true /* first */) ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { @@ -3637,13 +3723,18 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. - test.stopFn(t, s) + test.stopFn(t, s, false /* first */) ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } + // If test.startFn is nil, there is no way to restart router solications. + if test.startFn == nil { + return + } + // Start soliciting routers. test.startFn(t, s) waitForPkt(delay + defaultAsyncEventTimeout) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3e6196aee..9dcb1d52c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "log" "reflect" "sort" @@ -25,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" ) var ipv4BroadcastAddr = tcpip.ProtocolAddress{ @@ -54,7 +56,7 @@ type NIC struct { primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mcastJoins map[NetworkEndpointID]uint32 // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint @@ -121,15 +123,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC } nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]int32) + nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) nic.mu.ndp = ndpState{ - nic: nic, - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), + nic: nic, + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } // Register supported packet endpoint protocols. @@ -165,8 +167,17 @@ func (n *NIC) disable() *tcpip.Error { } n.mu.Lock() - defer n.mu.Unlock() + err := n.disableLocked() + n.mu.Unlock() + return err +} +// disableLocked disables n. +// +// It undoes the work done by enable. +// +// n MUST be locked. +func (n *NIC) disableLocked() *tcpip.Error { if !n.mu.enabled { return nil } @@ -189,7 +200,7 @@ func (n *NIC) disable() *tcpip.Error { } // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } } @@ -305,24 +316,33 @@ func (n *NIC) remove() *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - // Detach from link endpoint, so no packet comes in. - n.linkEP.Attach(nil) + n.disableLocked() + + // TODO(b/151378115): come up with a better way to pick an error than the + // first one. + var err *tcpip.Error + + // Forcefully leave multicast groups. + for nid := range n.mu.mcastJoins { + if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { + err = tempErr + } + } // Remove permanent and permanentTentative addresses, so no packet goes out. - var errs []*tcpip.Error for nid, ref := range n.mu.endpoints { switch ref.getKind() { case permanentTentative, permanent: - if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { - errs = append(errs, err) + if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { + err = tempErr } } } - if len(errs) > 0 { - return errs[0] - } - return nil + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + return err } // becomeIPv6Router transitions n into an IPv6 router. @@ -969,6 +989,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { for i, ref := range refs { if ref == r { n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + refs[len(refs)-1] = nil break } } @@ -996,8 +1017,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) if isIPv6Unicast { - // If we are removing a tentative IPv6 unicast address, stop - // DAD. + // If we are removing a tentative IPv6 unicast address, stop DAD. if kind == permanentTentative { n.mu.ndp.stopDuplicateAddressDetection(addr) } @@ -1005,7 +1025,10 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: r.ep.PrefixLen(), + }) } } @@ -1019,9 +1042,12 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing an IPv6 unicast address, leave the solicited-node // multicast address. + // + // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node + // multicast group may be left by user action. if isIPv6Unicast { snmc := header.SolicitedNodeAddr(addr) - if err := n.leaveGroupLocked(snmc); err != nil { + if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } } @@ -1081,26 +1107,31 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.leaveGroupLocked(addr) + return n.leaveGroupLocked(addr, false /* force */) } // leaveGroupLocked decrements the count for the given multicast address, and // when it reaches zero removes the endpoint for this address. n MUST be locked // before leaveGroupLocked is called. -func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +// +// If force is true, then the count for the multicast addres is ignored and the +// endpoint will be removed immediately. +func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - switch joins { - case 0: + joins, ok := n.mu.mcastJoins[id] + if !ok { // There are no joins with this address on this NIC. return tcpip.ErrBadLocalAddress - case 1: - // This is the last one, clean up. - if err := n.removePermanentAddressLocked(addr); err != nil { - return err - } } - n.mu.mcastJoins[id] = joins - 1 + + joins-- + if force || joins == 0 { + // There are no outstanding joins or we are forced to leave, clean up. + delete(n.mu.mcastJoins, id) + return n.removePermanentAddressLocked(addr) + } + + n.mu.mcastJoins[id] = joins return nil } @@ -1116,6 +1147,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -1186,6 +1218,16 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() return } + + // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. + if protocol == header.IPv4ProtocolNumber { + ipt := n.stack.IPTables() + if ok := ipt.Check(iptables.Prerouting, pkt); !ok { + // iptables is telling us to drop the packet. + return + } + } + if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return @@ -1201,10 +1243,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() return } - defer r.Release() - - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote // Found a NIC. n := r.ref.nic @@ -1213,24 +1251,33 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() n.mu.RUnlock() if ok { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, pkt) ref.decRef() - } else { - // n doesn't have a destination endpoint. - // Send the packet out of n. - pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) - pkt.Data.RemoveFirst() - - // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + r.Release() + return + } + + // n doesn't have a destination endpoint. + // Send the packet out of n. + // TODO(b/128629022): move this logic to route.WritePacket. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) + // forwarder will release route. + return } + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() + r.Release() + return } + + // The link-address resolution finished immediately. + n.forwardPacket(&r, protocol, pkt) + r.Release() return } @@ -1240,6 +1287,35 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } + } + + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) +} + // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index f9fd8f18f..fa28b46b1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -401,6 +401,9 @@ type LinkEndpoint interface { // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. + // + // Attach will be called with a nil dispatcher if the receiver's associated + // NIC is being removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 13354d884..6f423874a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -462,6 +462,10 @@ type Stack struct { // opaqueIIDOpts hold the options for generating opaque interface identifiers // (IIDs) as outlined by RFC 7217. opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // forwarder holds the packets that wait for their link-address resolutions + // to complete, and forwards them when each resolution is done. + forwarder *forwardQueue } // UniqueID is an abstract generator of unique identifiers. @@ -641,6 +645,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, + forwarder: newForwardQueue(), } // Add specified network protocols. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e15db40fb..9836b340f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -255,7 +255,7 @@ type linkEPWithMockedAttach struct { // Attach implements stack.LinkEndpoint.Attach. func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { l.LinkEndpoint.Attach(d) - l.attached = true + l.attached = d != nil } func (l *linkEPWithMockedAttach) isAttached() bool { @@ -566,7 +566,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) } if !e.isAttached() { - t.Fatalf("link endpoint not attached to a network disatcher") + t.Fatal("link endpoint not attached to a network dispatcher") } }) } @@ -631,196 +631,240 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { checkNIC(false) } -func TestRoutesWithDisabledNIC(t *testing.T) { - const unspecifiedNIC = 0 - const nicID1 = 1 - const nicID2 = 2 - +func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - ep1 := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) } +} - addr1 := tcpip.Address("\x01") - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } +func TestRemoveNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - ep2 := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addr2 := tcpip.Address("\x02") - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") } - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) + } + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } +} - // Test routes to odd address. - testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) - testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) - testRoute(t, s, nicID1, addr1, "\x05", addr1) +func TestRouteWithDownNIC(t *testing.T) { + tests := []struct { + name string + downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + }{ + { + name: "Disabled NIC", + downFn: (*stack.Stack).DisableNIC, + upFn: (*stack.Stack).EnableNIC, + }, + + // Once a NIC is removed, it cannot be brought up. + { + name: "Removed NIC", + downFn: (*stack.Stack).RemoveNIC, + }, + } - // Test routes to even address. - testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) - testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) - testRoute(t, s, nicID2, addr2, "\x06", addr2) - - // Disabling NIC1 should result in no routes to odd addresses. Routes to even - // addresses should continue to be available as NIC2 is still enabled. - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - nic1Dst := tcpip.Address("\x05") - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - nic2Dst := tcpip.Address("\x06") - testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) - testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) - testRoute(t, s, nicID2, addr2, nic2Dst, addr2) - - // Disabling NIC2 should result in no routes to even addresses. No route - // should be available to any address as routes to odd addresses were made - // unavailable by disabling NIC1 above. - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - - // Enabling NIC1 should make routes to odd addresses available again. Routes - // to even addresses should continue to be unavailable as NIC2 is still - // disabled. - if err := s.EnableNIC(nicID1); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) - } - testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) - testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) - testRoute(t, s, nicID1, addr1, nic1Dst, addr1) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) -} - -func TestRouteWritePacketWithDisabledNIC(t *testing.T) { const unspecifiedNIC = 0 const nicID1 = 1 const nicID2 = 2 + const addr1 = tcpip.Address("\x01") + const addr2 = tcpip.Address("\x02") + const nic1Dst = tcpip.Address("\x05") + const nic2Dst = tcpip.Address("\x06") + + setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } - addr1 := tcpip.Address("\x01") - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } - addr2 := tcpip.Address("\x02") - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + return s, ep1, ep2 } - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) + // Tests that routes through a down NIC are not used when looking up a route + // for a destination. + t.Run("Find", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, _, _ := setup(t) + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Bringing NIC1 down should result in no routes to odd addresses. Routes to + // even addresses should continue to be available as NIC2 is still up. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Bringing NIC2 down should result in no routes to even addresses. No + // route should be available to any address as routes to odd addresses + // were made unavailable by bringing NIC1 down above. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + if upFn := test.upFn; upFn != nil { + // Bringing NIC1 up should make routes to odd addresses available + // again. Routes to even addresses should continue to be unavailable + // as NIC2 is still down. + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + } + }) } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) - } + }) - nic1Dst := tcpip.Address("\x05") - r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) - } - defer r1.Release() + // Tests that writing a packet using a Route through a down NIC fails. + t.Run("WritePacket", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, ep1, ep2 := setup(t) - nic2Dst := tcpip.Address("\x06") - r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) - } - defer r2.Release() + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() - // If we failed to get routes r1 or r2, we cannot proceed with the test. - if t.Failed() { - t.FailNow() - } + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() - buf := buffer.View([]byte{1}) - testSend(t, r1, ep1, buf) - testSend(t, r2, ep2, buf) + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } - // Writes with Routes that use the disabled NIC1 should fail. - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testSend(t, r2, ep2, buf) + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) - // Writes with Routes that use the disabled NIC2 should fail. - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + // Writes with Routes that use NIC1 after being brought down should fail. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) - // Writes with Routes that use the re-enabled NIC1 should succeed. - // TODO(b/147015577): Should we instead completely invalidate all Routes that - // were bound to a disabled NIC at some point? - if err := s.EnableNIC(nicID1); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) - } - testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + // Writes with Routes that use NIC2 after being brought down should fail. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + if upFn := test.upFn; upFn != nil { + // Writes with Routes that use NIC1 after being brought up should + // succeed. + // + // TODO(b/147015577): Should we instead completely invalidate all + // Routes that were bound to a NIC that was brought down at some + // point? + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + } + }) + } + }) } func TestRoutes(t *testing.T) { @@ -2240,56 +2284,84 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) + const nicID1 = 1 + const nicID2 = 2 + const dstAddr = tcpip.Address("\x03") - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + tests := []struct { + name string + headerLen uint16 + }{ + { + name: "Zero header length", + }, + { + name: "Non-zero header length", + headerLen: 16, + }, } - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(true) - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) - } + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) + } - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + ep2 := channelLinkWithHeaderLength{ + Endpoint: channel.New(10, defaultMTU, ""), + headerLength: test.headerLen, + } + if err := s.CreateNIC(nicID2, &ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) + } - if _, ok := ep2.Read(); !ok { - t.Fatal("Packet not forwarded") - } + // Route all packets to dstAddr to NIC 2. + { + subnet, err := tcpip.NewSubnet(dstAddr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) + } - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } + // Send a packet to dstAddr. + buf := buffer.NewView(30) + buf[0] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not forwarded") + } + + // Test that the link's MaxHeaderLength is honoured. + if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { + t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + } + + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } + + if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + }) } } @@ -3010,6 +3082,50 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { } } +// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 +// address after leaving its solicited node multicast address does not result in +// an error. +func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + } + + // The NIC should have joined addr1's solicited node multicast address. + snmc := header.SolicitedNodeAddr(addr1) + in, err := s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if !in { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) + } + + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) + } + in, err = s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if in { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) + } + + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + } +} + func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 778c0a4d6..d4c0359e8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,9 +15,9 @@ package stack import ( + "container/heap" "fmt" "math/rand" - "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -141,16 +141,17 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t epsByNic.mu.Lock() defer epsByNic.mu.Unlock() - if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { - // There was already a bind. - return multiPortEp.singleRegisterEndpoint(t, reusePort) + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + multiPortEp = &multiPortEndpoint{ + demux: d, + netProto: netProto, + transProto: transProto, + reuse: reusePort, + } + epsByNic.endpoints[bindToDevice] = multiPortEp } - // This is a new binding. - multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.reuse = reusePort - epsByNic.endpoints[bindToDevice] = multiPortEp return multiPortEp.singleRegisterEndpoint(t, reusePort) } @@ -222,6 +223,35 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } +type transportEndpointHeap []TransportEndpoint + +var _ heap.Interface = (*transportEndpointHeap)(nil) + +func (h *transportEndpointHeap) Len() int { + return len(*h) +} + +func (h *transportEndpointHeap) Less(i, j int) bool { + return (*h)[i].UniqueID() < (*h)[j].UniqueID() +} + +func (h *transportEndpointHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +func (h *transportEndpointHeap) Push(x interface{}) { + *h = append(*h, x.(TransportEndpoint)) +} + +func (h *transportEndpointHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + old[n-1] = nil + *h = old[:n-1] + return x +} + // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -237,15 +267,14 @@ type multiPortEndpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - endpointsArr []TransportEndpoint - endpointsMap map[TransportEndpoint]int + endpoints transportEndpointHeap // reuse indicates if more than one endpoint is allowed. reuse bool } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + eps := append([]TransportEndpoint(nil), ep.endpoints...) ep.mu.RUnlock() return eps } @@ -262,8 +291,8 @@ func reciprocalScale(val, n uint32) uint32 { // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpointsArr) == 1 { - return mpep.endpointsArr[0] + if len(mpep.endpoints) == 1 { + return mpep.endpoints[0] } payload := []byte{ @@ -279,29 +308,26 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) - return mpep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) + return mpep.endpoints[idx] } func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] - for i, endpoint := range ep.endpointsArr { - // HandlePacket takes ownership of pkt, so each endpoint needs - // its own copy except for the final one. - if i == len(ep.endpointsArr)-1 { - if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) - break - } - endpoint.HandlePacket(r, id, pkt) - break - } + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. + for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) - continue + } else { + endpoint.HandlePacket(r, id, pkt.Clone()) } - endpoint.HandlePacket(r, id, pkt.Clone()) + } + if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + } else { + endpoint.HandlePacket(r, id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -312,26 +338,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo ep.mu.Lock() defer ep.mu.Unlock() - if len(ep.endpointsArr) > 0 { + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if !ep.reuse || !reusePort { return tcpip.ErrPortInUse } } - // A new endpoint is added into endpointsArr and its index there is saved in - // endpointsMap. This will allow us to remove endpoint from the array fast. - ep.endpointsMap[t] = len(ep.endpointsArr) - ep.endpointsArr = append(ep.endpointsArr, t) + heap.Push(&ep.endpoints, t) - // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints - // can be restored in the same order. - sort.Slice(ep.endpointsArr, func(i, j int) bool { - return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() - }) - for i, e := range ep.endpointsArr { - ep.endpointsMap[e] = i - } return nil } @@ -340,21 +355,13 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { ep.mu.Lock() defer ep.mu.Unlock() - idx, ok := ep.endpointsMap[t] - if !ok { - return false - } - delete(ep.endpointsMap, t) - l := len(ep.endpointsArr) - if l > 1 { - // The last endpoint in endpointsArr is moved instead of the deleted one. - lastEp := ep.endpointsArr[l-1] - ep.endpointsArr[idx] = lastEp - ep.endpointsMap[lastEp] = idx - ep.endpointsArr = ep.endpointsArr[0 : l-1] - return false + for i, endpoint := range ep.endpoints { + if endpoint == t { + heap.Remove(&ep.endpoints, i) + break + } } - return true + return len(ep.endpoints) == 0 } func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { @@ -371,17 +378,14 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - if epsByNic, ok := eps.endpoints[id]; ok { - // There was already a binding. - return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) - } - - // This is a new binding. - epsByNic := &endpointsByNic{ - endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + epsByNic, ok := eps.endpoints[id] + if !ok { + epsByNic = &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), + } + eps.endpoints[id] = epsByNic } - eps.endpoints[id] = epsByNic return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } @@ -396,14 +400,6 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN } } -var loopbackSubnet = func() tcpip.Subnet { - sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") - if err != nil { - panic(err) - } - return sn -}() - // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. @@ -413,61 +409,45 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - eps.mu.RLock() - - // Determine which transport endpoint or endpoints to deliver this packet to. // If the packet is a UDP broadcast or multicast, then find all matching - // transport endpoints. If the packet is a TCP packet with a non-unicast - // source or destination address, then do nothing further and instruct - // the caller to do the same. - var destEps []*endpointsByNic - switch protocol { - case header.UDPProtocolNumber: - if isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, id) - break - } - - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) + // transport endpoints. + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + eps.mu.RLock() + destEPs := d.findAllEndpointsLocked(eps, id) + eps.mu.RUnlock() + // Fail if we didn't find at least one matching transport endpoint. + if len(destEPs) == 0 { + r.Stats().UDP.UnknownPortErrors.Increment() + return false } - - case header.TCPProtocolNumber: - if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { - // TCP can only be used to communicate between a single - // source and a single destination; the addresses must - // be unicast. - eps.mu.RUnlock() - r.Stats().TCP.InvalidSegmentsReceived.Increment() - return true + // handlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEPs[:len(destEPs)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + return true + } - fallthrough - - default: - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) - } + // If the packet is a TCP packet with a non-unicast source or destination + // address, then do nothing further and instruct the caller to do the same. + if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single source and a + // single destination; the addresses must be unicast. + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true } + eps.mu.RLock() + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() - - // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { - // UDP packet could not be delivered to an unknown destination port. + if ep == nil { if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() } return false } - - // HandlePacket takes ownership of pkt, so each endpoint needs its own - // copy except for the final one. - for _, ep := range destEps[:len(destEps)-1] { - ep.handlePacket(r, id, pkt.Clone()) - } - destEps[len(destEps)-1].handlePacket(r, id, pkt) - + ep.handlePacket(r, id, pkt) return true } @@ -519,11 +499,17 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic +// iterEndpointsLocked yields all endpointsByNic in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with the id minus the local address. @@ -531,7 +517,9 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with the id minus the remote part. @@ -539,14 +527,26 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } +} + +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic + d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) return matchedEPs } @@ -584,10 +584,12 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN // findEndpointLocked returns the endpoint that most closely matches the given // id. func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { - return matchedEPs[0] - } - return nil + var matchedEP *endpointsByNic + d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { + matchedEP = ep + return false + }) + return matchedEP } // registerRawEndpoint registers the given endpoint with the dispatcher such @@ -601,8 +603,8 @@ func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNum } eps.mu.Lock() - defer eps.mu.Unlock() eps.rawEndpoints = append(eps.rawEndpoints, ep) + eps.mu.Unlock() return nil } @@ -616,13 +618,16 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } eps.mu.Lock() - defer eps.mu.Unlock() for i, rawEP := range eps.rawEndpoints { if rawEP == ep { - eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) - return + lastIdx := len(eps.rawEndpoints) - 1 + eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx] + eps.rawEndpoints[lastIdx] = nil + eps.rawEndpoints = eps.rawEndpoints[:lastIdx] + break } } + eps.mu.Unlock() } func isMulticastOrBroadcast(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 5e9237de9..0e3e239c5 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -167,8 +167,18 @@ func TestTransportDemuxerRegister(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) } }) diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 272e8f570..a32f9eacf 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -32,6 +32,7 @@ go_library( srcs = [ "accept.go", "connect.go", + "connect_unsafe.go", "cubic.go", "cubic_state.go", "dispatcher.go", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index ae4f3f3a9..be86af502 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -624,17 +624,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions { var optionPool = sync.Pool{ New: func() interface{} { - return make([]byte, maxOptionSize) + return &[maxOptionSize]byte{} }, } func getOptions() []byte { - return optionPool.Get().([]byte) + return (*optionPool.Get().(*[maxOptionSize]byte))[:] } func putOptions(options []byte) { // Reslice to full capacity. - optionPool.Put(options[0:cap(options)]) + optionPool.Put(optionsToArray(options)) } func makeSynOptions(opts header.TCPSynOptions) []byte { @@ -1639,6 +1639,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { const timeWaitDone = 3 s := sleep.Sleeper{} + defer s.Done() s.AddWaker(&e.newSegmentWaker, newSegment) s.AddWaker(&e.notificationWaker, notification) diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go new file mode 100644 index 000000000..cfc304616 --- /dev/null +++ b/pkg/tcpip/transport/tcp/connect_unsafe.go @@ -0,0 +1,30 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "reflect" + "unsafe" +) + +// optionsToArray converts a slice of capacity >-= maxOptionSize to an array. +// +// optionsToArray panics if the capacity of options is smaller than +// maxOptionSize. +func optionsToArray(options []byte) *[maxOptionSize]byte { + // Reslice to full capacity. + options = options[0:maxOptionSize] + return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data)) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 40cc664c0..5187a5e25 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -825,6 +825,7 @@ func (e *endpoint) Abort() { func (e *endpoint) Close() { e.mu.Lock() closed := e.closed + e.closed = true e.mu.Unlock() if closed { return @@ -833,13 +834,7 @@ func (e *endpoint) Close() { // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - e.closeNoShutdown() -} -// closeNoShutdown closes the endpoint without doing a full shutdown. This is -// used when a connection needs to be aborted with a RST and we want to skip -// a full 4 way TCP shutdown. -func (e *endpoint) closeNoShutdown() { e.mu.Lock() // For listening sockets, we always release ports inline so that they @@ -858,11 +853,8 @@ func (e *endpoint) closeNoShutdown() { e.boundPortFlags = ports.Flags{} } - // Mark endpoint as closed. - e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. - tcpip.AddDanglingEndpoint(e) switch e.EndpointState() { // Sockets in StateSynRecv state(passive connections) are closed when // the handshake fails or if the listening socket is closed while @@ -876,6 +868,9 @@ func (e *endpoint) closeNoShutdown() { // do nothing. default: e.workerCleanup = true + tcpip.AddDanglingEndpoint(e) + // Worker will remove the dangling endpoint when the endpoint + // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) } @@ -2117,10 +2112,13 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Close for write. if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { e.sndBufMu.Lock() - if e.sndClosed { // Already closed. e.sndBufMu.Unlock() + if e.EndpointState() == StateTimeWait { + e.mu.Unlock() + return tcpip.ErrNotConnected + } break } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 1c10da5ca..5d0bc4f72 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -56,9 +56,9 @@ type segment struct { options []byte `state:".([]byte)"` hasNewSACKInfo bool rcvdTime time.Time `state:".(unixTime)"` - // xmitTime is the last transmit time of this segment. A zero value - // indicates that the segment has yet to be transmitted. - xmitTime time.Time `state:".(unixTime)"` + // xmitTime is the last transmit time of this segment. + xmitTime time.Time `state:".(unixTime)"` + xmitCount uint32 } func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go index e28f213ba..8d3ddce4b 100644 --- a/pkg/tcpip/transport/tcp/segment_heap.go +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -14,21 +14,25 @@ package tcp +import "container/heap" + type segmentHeap []*segment +var _ heap.Interface = (*segmentHeap)(nil) + // Len returns the length of h. -func (h segmentHeap) Len() int { - return len(h) +func (h *segmentHeap) Len() int { + return len(*h) } // Less determines whether the i-th element of h is less than the j-th element. -func (h segmentHeap) Less(i, j int) bool { - return h[i].sequenceNumber.LessThan(h[j].sequenceNumber) +func (h *segmentHeap) Less(i, j int) bool { + return (*h)[i].sequenceNumber.LessThan((*h)[j].sequenceNumber) } // Swap swaps the i-th and j-th elements of h. -func (h segmentHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] +func (h *segmentHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } // Push adds x as the last element of h. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index b74b61e7d..657c3146e 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1229,7 +1229,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { // sendSegment sends the specified segment. func (s *sender) sendSegment(seg *segment) *tcpip.Error { - if !seg.xmitTime.IsZero() { + if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { @@ -1237,6 +1237,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { } } seg.xmitTime = time.Now() + seg.xmitCount++ return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) } |