diff options
Diffstat (limited to 'pkg')
534 files changed, 35034 insertions, 55870 deletions
diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD deleted file mode 100644 index 32c601a03..000000000 --- a/pkg/abi/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "abi", - srcs = [ - "abi.go", - "abi_linux.go", - "flag.go", - ], - importpath = "gvisor.dev/gvisor/pkg/abi", - visibility = ["//:sandbox"], -) diff --git a/pkg/abi/abi_state_autogen.go b/pkg/abi/abi_state_autogen.go new file mode 100755 index 000000000..4f94570e5 --- /dev/null +++ b/pkg/abi/abi_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package abi + diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD deleted file mode 100644 index 6960add0c..000000000 --- a/pkg/abi/linux/BUILD +++ /dev/null @@ -1,64 +0,0 @@ -# Package linux contains the constants and types needed to inferface with a -# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix -# when the host OS may not be Linux. - -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "linux", - srcs = [ - "aio.go", - "ashmem.go", - "audit.go", - "binder.go", - "bpf.go", - "capability.go", - "dev.go", - "elf.go", - "errors.go", - "eventfd.go", - "exec.go", - "fcntl.go", - "file.go", - "fs.go", - "futex.go", - "inotify.go", - "ioctl.go", - "ip.go", - "ipc.go", - "limits.go", - "linux.go", - "mm.go", - "netdevice.go", - "netfilter.go", - "netlink.go", - "netlink_route.go", - "poll.go", - "prctl.go", - "ptrace.go", - "rusage.go", - "sched.go", - "seccomp.go", - "sem.go", - "shm.go", - "signal.go", - "socket.go", - "splice.go", - "tcp.go", - "time.go", - "timer.go", - "tty.go", - "uio.go", - "utsname.go", - "wait.go", - ], - importpath = "gvisor.dev/gvisor/pkg/abi/linux", - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi", - "//pkg/binary", - "//pkg/bits", - ], -) diff --git a/pkg/abi/linux/linux_state_autogen.go b/pkg/abi/linux/linux_state_autogen.go new file mode 100755 index 000000000..048ee5a25 --- /dev/null +++ b/pkg/abi/linux/linux_state_autogen.go @@ -0,0 +1,68 @@ +// automatically generated by stateify. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *BPFInstruction) beforeSave() {} +func (x *BPFInstruction) save(m state.Map) { + x.beforeSave() + m.Save("OpCode", &x.OpCode) + m.Save("JumpIfTrue", &x.JumpIfTrue) + m.Save("JumpIfFalse", &x.JumpIfFalse) + m.Save("K", &x.K) +} + +func (x *BPFInstruction) afterLoad() {} +func (x *BPFInstruction) load(m state.Map) { + m.Load("OpCode", &x.OpCode) + m.Load("JumpIfTrue", &x.JumpIfTrue) + m.Load("JumpIfFalse", &x.JumpIfFalse) + m.Load("K", &x.K) +} + +func (x *KernelTermios) beforeSave() {} +func (x *KernelTermios) save(m state.Map) { + x.beforeSave() + m.Save("InputFlags", &x.InputFlags) + m.Save("OutputFlags", &x.OutputFlags) + m.Save("ControlFlags", &x.ControlFlags) + m.Save("LocalFlags", &x.LocalFlags) + m.Save("LineDiscipline", &x.LineDiscipline) + m.Save("ControlCharacters", &x.ControlCharacters) + m.Save("InputSpeed", &x.InputSpeed) + m.Save("OutputSpeed", &x.OutputSpeed) +} + +func (x *KernelTermios) afterLoad() {} +func (x *KernelTermios) load(m state.Map) { + m.Load("InputFlags", &x.InputFlags) + m.Load("OutputFlags", &x.OutputFlags) + m.Load("ControlFlags", &x.ControlFlags) + m.Load("LocalFlags", &x.LocalFlags) + m.Load("LineDiscipline", &x.LineDiscipline) + m.Load("ControlCharacters", &x.ControlCharacters) + m.Load("InputSpeed", &x.InputSpeed) + m.Load("OutputSpeed", &x.OutputSpeed) +} + +func (x *WindowSize) beforeSave() {} +func (x *WindowSize) save(m state.Map) { + x.beforeSave() + m.Save("Rows", &x.Rows) + m.Save("Cols", &x.Cols) +} + +func (x *WindowSize) afterLoad() {} +func (x *WindowSize) load(m state.Map) { + m.Load("Rows", &x.Rows) + m.Load("Cols", &x.Cols) +} + +func init() { + state.Register("linux.BPFInstruction", (*BPFInstruction)(nil), state.Fns{Save: (*BPFInstruction).save, Load: (*BPFInstruction).load}) + state.Register("linux.KernelTermios", (*KernelTermios)(nil), state.Fns{Save: (*KernelTermios).save, Load: (*KernelTermios).load}) + state.Register("linux.WindowSize", (*WindowSize)(nil), state.Fns{Save: (*WindowSize).save, Load: (*WindowSize).load}) +} diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD deleted file mode 100644 index 7831471ca..000000000 --- a/pkg/amutex/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "amutex", - srcs = ["amutex.go"], - importpath = "gvisor.dev/gvisor/pkg/amutex", - visibility = ["//:sandbox"], -) - -go_test( - name = "amutex_test", - size = "small", - srcs = ["amutex_test.go"], - embed = [":amutex"], - tags = ["flaky"], -) diff --git a/pkg/amutex/amutex_state_autogen.go b/pkg/amutex/amutex_state_autogen.go new file mode 100755 index 000000000..5651ae4e9 --- /dev/null +++ b/pkg/amutex/amutex_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package amutex + diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go deleted file mode 100644 index 211bdda4b..000000000 --- a/pkg/amutex/amutex_test.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package amutex - -import ( - "sync" - "testing" - "time" -) - -type sleeper struct { - ch chan struct{} -} - -func (s *sleeper) SleepStart() <-chan struct{} { - return s.ch -} - -func (*sleeper) SleepFinish(bool) { -} - -func (s *sleeper) Interrupted() bool { - return len(s.ch) != 0 -} - -func TestMutualExclusion(t *testing.T) { - var m AbortableMutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end of the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes to to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock(nil) - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestAbortWait(t *testing.T) { - var s sleeper - var m AbortableMutex - m.Init() - - // Lock the mutex. - m.Lock(&s) - - // Lock again, but this time cancel after 500ms. - s.ch = make(chan struct{}, 1) - go func() { - time.Sleep(500 * time.Millisecond) - s.ch <- struct{}{} - }() - if v := m.Lock(&s); v { - t.Fatalf("Lock succeeded when it should have failed") - } - - // Lock again, but cancel right away. - s.ch <- struct{}{} - if v := m.Lock(&s); v { - t.Fatalf("Lock succeeded when it should have failed") - } -} diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD deleted file mode 100644 index 47ab65346..000000000 --- a/pkg/atomicbitops/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "atomicbitops", - srcs = [ - "atomic_bitops.go", - "atomic_bitops_amd64.s", - "atomic_bitops_common.go", - ], - importpath = "gvisor.dev/gvisor/pkg/atomicbitops", - visibility = ["//:sandbox"], -) - -go_test( - name = "atomicbitops_test", - size = "small", - srcs = ["atomic_bitops_test.go"], - embed = [":atomicbitops"], -) diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomic_bitops_test.go deleted file mode 100644 index 965e9be79..000000000 --- a/pkg/atomicbitops/atomic_bitops_test.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package atomicbitops - -import ( - "runtime" - "sync" - "testing" -) - -const iterations = 100 - -func detectRaces32(val, target uint32, fn func(*uint32, uint32)) bool { - runtime.GOMAXPROCS(100) - for n := 0; n < iterations; n++ { - x := val - var wg sync.WaitGroup - for i := uint32(0); i < 32; i++ { - wg.Add(1) - go func(a *uint32, i uint32) { - defer wg.Done() - fn(a, uint32(1<<i)) - }(&x, i) - } - wg.Wait() - if x != target { - return true - } - } - return false -} - -func detectRaces64(val, target uint64, fn func(*uint64, uint64)) bool { - runtime.GOMAXPROCS(100) - for n := 0; n < iterations; n++ { - x := val - var wg sync.WaitGroup - for i := uint64(0); i < 64; i++ { - wg.Add(1) - go func(a *uint64, i uint64) { - defer wg.Done() - fn(a, uint64(1<<i)) - }(&x, i) - } - wg.Wait() - if x != target { - return true - } - } - return false -} - -func TestOrUint32(t *testing.T) { - if detectRaces32(0x0, 0xffffffff, OrUint32) { - t.Error("Data race detected!") - } -} - -func TestAndUint32(t *testing.T) { - if detectRaces32(0xf0f0f0f0, 0x00000000, AndUint32) { - t.Error("Data race detected!") - } -} - -func TestXorUint32(t *testing.T) { - if detectRaces32(0xf0f0f0f0, 0x0f0f0f0f, XorUint32) { - t.Error("Data race detected!") - } -} - -func TestOrUint64(t *testing.T) { - if detectRaces64(0x0, 0xffffffffffffffff, OrUint64) { - t.Error("Data race detected!") - } -} - -func TestAndUint64(t *testing.T) { - if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0, AndUint64) { - t.Error("Data race detected!") - } -} - -func TestXorUint64(t *testing.T) { - if detectRaces64(0xf0f0f0f0f0f0f0f0, 0x0f0f0f0f0f0f0f0f, XorUint64) { - t.Error("Data race detected!") - } -} - -func TestCompareAndSwapUint32(t *testing.T) { - tests := []struct { - name string - prev uint32 - old uint32 - new uint32 - next uint32 - }{ - { - name: "Successful compare-and-swap with prev == new", - prev: 10, - old: 10, - new: 10, - next: 10, - }, - { - name: "Successful compare-and-swap with prev != new", - prev: 20, - old: 20, - new: 22, - next: 22, - }, - { - name: "Failed compare-and-swap with prev == new", - prev: 31, - old: 30, - new: 31, - next: 31, - }, - { - name: "Failed compare-and-swap with prev != new", - prev: 41, - old: 40, - new: 42, - next: 41, - }, - } - for _, test := range tests { - val := test.prev - prev := CompareAndSwapUint32(&val, test.old, test.new) - if got, want := prev, test.prev; got != want { - t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want) - } - if got, want := val, test.next; got != want { - t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want) - } - } -} - -func TestCompareAndSwapUint64(t *testing.T) { - tests := []struct { - name string - prev uint64 - old uint64 - new uint64 - next uint64 - }{ - { - name: "Successful compare-and-swap with prev == new", - prev: 0x100000000, - old: 0x100000000, - new: 0x100000000, - next: 0x100000000, - }, - { - name: "Successful compare-and-swap with prev != new", - prev: 0x200000000, - old: 0x200000000, - new: 0x200000002, - next: 0x200000002, - }, - { - name: "Failed compare-and-swap with prev == new", - prev: 0x300000001, - old: 0x300000000, - new: 0x300000001, - next: 0x300000001, - }, - { - name: "Failed compare-and-swap with prev != new", - prev: 0x400000001, - old: 0x400000000, - new: 0x400000002, - next: 0x400000001, - }, - } - for _, test := range tests { - val := test.prev - prev := CompareAndSwapUint64(&val, test.old, test.new) - if got, want := prev, test.prev; got != want { - t.Errorf("%s: incorrect returned previous value: got %d, expected %d", test.name, got, want) - } - if got, want := val, test.next; got != want { - t.Errorf("%s: incorrect value stored in val: got %d, expected %d", test.name, got, want) - } - } -} - -func TestIncUnlessZeroInt32(t *testing.T) { - for _, test := range []struct { - initial int32 - final int32 - ret bool - }{ - { - initial: 0, - final: 0, - ret: false, - }, - { - initial: 1, - final: 2, - ret: true, - }, - { - initial: 2, - final: 3, - ret: true, - }, - } { - val := test.initial - if got, want := IncUnlessZeroInt32(&val), test.ret; got != want { - t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want) - } - if got, want := val, test.final; got != want { - t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want) - } - } -} - -func TestDecUnlessOneInt32(t *testing.T) { - for _, test := range []struct { - initial int32 - final int32 - ret bool - }{ - { - initial: 0, - final: -1, - ret: true, - }, - { - initial: 1, - final: 1, - ret: false, - }, - { - initial: 2, - final: 1, - ret: true, - }, - } { - val := test.initial - if got, want := DecUnlessOneInt32(&val), test.ret; got != want { - t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want) - } - if got, want := val, test.final; got != want { - t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want) - } - } -} diff --git a/pkg/atomicbitops/atomicbitops_state_autogen.go b/pkg/atomicbitops/atomicbitops_state_autogen.go new file mode 100755 index 000000000..a74ea7d50 --- /dev/null +++ b/pkg/atomicbitops/atomicbitops_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package atomicbitops + diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD deleted file mode 100644 index 09d6c2c1f..000000000 --- a/pkg/binary/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "binary", - srcs = ["binary.go"], - importpath = "gvisor.dev/gvisor/pkg/binary", - visibility = ["//:sandbox"], -) - -go_test( - name = "binary_test", - size = "small", - srcs = ["binary_test.go"], - embed = [":binary"], -) diff --git a/pkg/binary/binary_state_autogen.go b/pkg/binary/binary_state_autogen.go new file mode 100755 index 000000000..e29aeb344 --- /dev/null +++ b/pkg/binary/binary_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package binary + diff --git a/pkg/binary/binary_test.go b/pkg/binary/binary_test.go deleted file mode 100644 index 4d609a438..000000000 --- a/pkg/binary/binary_test.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package binary - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "reflect" - "strings" - "testing" -) - -func newInt32(i int32) *int32 { - return &i -} - -func TestSize(t *testing.T) { - if got, want := Size(uint32(10)), uintptr(4); got != want { - t.Errorf("Got = %d, want = %d", got, want) - } -} - -func TestPanic(t *testing.T) { - tests := []struct { - name string - f func([]byte, binary.ByteOrder, interface{}) - data interface{} - want string - }{ - {"Unmarshal int", Unmarshal, 5, "invalid type: int"}, - {"Unmarshal []int", Unmarshal, []int{5}, "invalid type: int"}, - {"Marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, 5, "invalid type: int"}, - {"Marshal int[]", func(_ []byte, bo binary.ByteOrder, d interface{}) { Marshal(nil, bo, d) }, []int{5}, "invalid type: int"}, - {"Unmarshal short buffer", Unmarshal, newInt32(5), "runtime error: index out of range"}, - {"Unmarshal long buffer", func(_ []byte, bo binary.ByteOrder, d interface{}) { Unmarshal(make([]byte, 50), bo, d) }, newInt32(5), "buffer too long by 46 bytes"}, - {"marshal int", func(_ []byte, bo binary.ByteOrder, d interface{}) { marshal(nil, bo, reflect.ValueOf(d)) }, 5, "invalid type: int"}, - {"Size int", func(_ []byte, _ binary.ByteOrder, d interface{}) { Size(d) }, 5, "invalid type: int"}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - defer func() { - r := recover() - if got := fmt.Sprint(r); !strings.HasPrefix(got, test.want) { - t.Errorf("Got recover() = %q, want prefix = %q", got, test.want) - } - }() - - test.f(nil, LittleEndian, test.data) - }) - } -} - -type inner struct { - Field int32 -} - -type outer struct { - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - - Slice []int32 - Array [5]int32 - Struct inner -} - -func TestMarshalUnmarshal(t *testing.T) { - want := outer{ - 1, 2, 3, 4, 5, 6, 7, 8, - []int32{9, 10, 11}, - [5]int32{12, 13, 14, 15, 16}, - inner{17}, - } - buf := Marshal(nil, LittleEndian, want) - got := outer{Slice: []int32{0, 0, 0}} - Unmarshal(buf, LittleEndian, &got) - if !reflect.DeepEqual(&got, &want) { - t.Errorf("Got = %#v, want = %#v", got, want) - } -} - -type outerBenchmark struct { - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - - Array [5]int32 - Struct inner -} - -func BenchmarkMarshalUnmarshal(b *testing.B) { - b.ReportAllocs() - - in := outerBenchmark{ - 1, 2, 3, 4, 5, 6, 7, 8, - [5]int32{9, 10, 11, 12, 13}, - inner{14}, - } - buf := make([]byte, Size(&in)) - out := outerBenchmark{} - - for i := 0; i < b.N; i++ { - buf := Marshal(buf[:0], LittleEndian, &in) - Unmarshal(buf, LittleEndian, &out) - } -} - -func BenchmarkReadWrite(b *testing.B) { - b.ReportAllocs() - - in := outerBenchmark{ - 1, 2, 3, 4, 5, 6, 7, 8, - [5]int32{9, 10, 11, 12, 13}, - inner{14}, - } - buf := bytes.NewBuffer(make([]byte, binary.Size(&in))) - out := outerBenchmark{} - - for i := 0; i < b.N; i++ { - buf.Reset() - if err := binary.Write(buf, LittleEndian, &in); err != nil { - b.Error("Write:", err) - } - if err := binary.Read(buf, LittleEndian, &out); err != nil { - b.Error("Read:", err) - } - } -} - -type outerPadding struct { - _ int8 - _ int16 - _ int32 - _ int64 - _ uint8 - _ uint16 - _ uint32 - _ uint64 - - _ []int32 - _ [5]int32 - _ inner -} - -func TestMarshalUnmarshalPadding(t *testing.T) { - var want outerPadding - buf := Marshal(nil, LittleEndian, want) - var got outerPadding - Unmarshal(buf, LittleEndian, &got) - if !reflect.DeepEqual(&got, &want) { - t.Errorf("Got = %#v, want = %#v", got, want) - } -} - -// Numbers with bits in every byte that distinguishable in big and little endian. -const ( - want16 = 64<<8 | 128 - want32 = 16<<24 | 32<<16 | want16 - want64 = 1<<56 | 2<<48 | 4<<40 | 8<<32 | want32 -) - -func TestReadWriteUint16(t *testing.T) { - const want = uint16(want16) - var buf bytes.Buffer - if err := WriteUint16(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint16:", err) - } - got, err := ReadUint16(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint16:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestReadWriteUint32(t *testing.T) { - const want = uint32(want32) - var buf bytes.Buffer - if err := WriteUint32(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint32:", err) - } - got, err := ReadUint32(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint32:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -func TestReadWriteUint64(t *testing.T) { - const want = uint64(want64) - var buf bytes.Buffer - if err := WriteUint64(&buf, LittleEndian, want); err != nil { - t.Error("WriteUint64:", err) - } - got, err := ReadUint64(&buf, LittleEndian) - if err != nil { - t.Error("ReadUint64:", err) - } - if got != want { - t.Errorf("got = %d, want = %d", got, want) - } -} - -type readWriter struct { - err error -} - -func (rw *readWriter) Write([]byte) (int, error) { - return 0, rw.err -} - -func (rw *readWriter) Read([]byte) (int, error) { - return 0, rw.err -} - -func TestReadWriteError(t *testing.T) { - tests := []struct { - name string - f func(rw io.ReadWriter) error - }{ - {"WriteUint16", func(rw io.ReadWriter) error { return WriteUint16(rw, LittleEndian, 0) }}, - {"ReadUint16", func(rw io.ReadWriter) error { _, err := ReadUint16(rw, LittleEndian); return err }}, - {"WriteUint32", func(rw io.ReadWriter) error { return WriteUint32(rw, LittleEndian, 0) }}, - {"ReadUint32", func(rw io.ReadWriter) error { _, err := ReadUint32(rw, LittleEndian); return err }}, - {"WriteUint64", func(rw io.ReadWriter) error { return WriteUint64(rw, LittleEndian, 0) }}, - {"ReadUint64", func(rw io.ReadWriter) error { _, err := ReadUint64(rw, LittleEndian); return err }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - want := errors.New("want") - if got := test.f(&readWriter{want}); got != want { - t.Errorf("got = %v, want = %v", got, want) - } - }) - } -} diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD deleted file mode 100644 index 0c2dde4f8..000000000 --- a/pkg/bits/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -go_library( - name = "bits", - srcs = [ - "bits.go", - "bits32.go", - "bits64.go", - "uint64_arch_amd64.go", - "uint64_arch_amd64_asm.s", - "uint64_arch_generic.go", - ], - importpath = "gvisor.dev/gvisor/pkg/bits", - visibility = ["//:sandbox"], -) - -go_template( - name = "bits_template", - srcs = ["bits_template.go"], - types = [ - "T", - ], -) - -go_template_instance( - name = "bits64", - out = "bits64.go", - package = "bits", - suffix = "64", - template = ":bits_template", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "bits32", - out = "bits32.go", - package = "bits", - suffix = "32", - template = ":bits_template", - types = { - "T": "uint32", - }, -) - -go_test( - name = "bits_test", - size = "small", - srcs = ["uint64_test.go"], - embed = [":bits"], -) diff --git a/pkg/bits/bits32.go b/pkg/bits/bits32.go new file mode 100755 index 000000000..4e9e45dce --- /dev/null +++ b/pkg/bits/bits32.go @@ -0,0 +1,25 @@ +package bits + +// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. +func IsOn32(mask, bits uint32) bool { + return mask&bits == bits +} + +// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. +func IsAnyOn32(mask, bits uint32) bool { + return mask&bits != 0 +} + +// Mask returns a T with all of the given bits set. +func Mask32(is ...int) uint32 { + ret := uint32(0) + for _, i := range is { + ret |= MaskOf32(i) + } + return ret +} + +// MaskOf is like Mask, but sets only a single bit (more efficiently). +func MaskOf32(i int) uint32 { + return uint32(1) << uint32(i) +} diff --git a/pkg/bits/bits64.go b/pkg/bits/bits64.go new file mode 100755 index 000000000..f49158792 --- /dev/null +++ b/pkg/bits/bits64.go @@ -0,0 +1,25 @@ +package bits + +// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. +func IsOn64(mask, bits uint64) bool { + return mask&bits == bits +} + +// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. +func IsAnyOn64(mask, bits uint64) bool { + return mask&bits != 0 +} + +// Mask returns a T with all of the given bits set. +func Mask64(is ...int) uint64 { + ret := uint64(0) + for _, i := range is { + ret |= MaskOf64(i) + } + return ret +} + +// MaskOf is like Mask, but sets only a single bit (more efficiently). +func MaskOf64(i int) uint64 { + return uint64(1) << uint64(i) +} diff --git a/pkg/bits/bits_state_autogen.go b/pkg/bits/bits_state_autogen.go new file mode 100755 index 000000000..2abb1291b --- /dev/null +++ b/pkg/bits/bits_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package bits + diff --git a/pkg/bits/bits_template.go b/pkg/bits/bits_template.go deleted file mode 100644 index 93a435b80..000000000 --- a/pkg/bits/bits_template.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bits - -// Non-atomic bit operations on a template type T. - -// T is a required type parameter that must be an integral type. -type T uint64 - -// IsOn returns true if *all* bits set in 'bits' are set in 'mask'. -func IsOn(mask, bits T) bool { - return mask&bits == bits -} - -// IsAnyOn returns true if *any* bit set in 'bits' is set in 'mask'. -func IsAnyOn(mask, bits T) bool { - return mask&bits != 0 -} - -// Mask returns a T with all of the given bits set. -func Mask(is ...int) T { - ret := T(0) - for _, i := range is { - ret |= MaskOf(i) - } - return ret -} - -// MaskOf is like Mask, but sets only a single bit (more efficiently). -func MaskOf(i int) T { - return T(1) << T(i) -} diff --git a/pkg/bits/uint64_test.go b/pkg/bits/uint64_test.go deleted file mode 100644 index 1b018d808..000000000 --- a/pkg/bits/uint64_test.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bits - -import ( - "reflect" - "testing" -) - -func TestTrailingZeros64(t *testing.T) { - for i := 0; i <= 64; i++ { - n := uint64(1) << uint(i) - if got, want := TrailingZeros64(n), i; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) << uint(i) - if got, want := TrailingZeros64(n), i; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) >> uint(i) - if got, want := TrailingZeros64(n), 0; got != want { - t.Errorf("TrailingZeros64(%#x): got %d, wanted %d", n, got, want) - } - } -} - -func TestMostSignificantOne64(t *testing.T) { - for i := 0; i <= 64; i++ { - n := uint64(1) << uint(i) - if got, want := MostSignificantOne64(n), i; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) >> uint(i) - if got, want := MostSignificantOne64(n), 63-i; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } - - for i := 0; i < 64; i++ { - n := ^uint64(0) << uint(i) - if got, want := MostSignificantOne64(n), 63; got != want { - t.Errorf("MostSignificantOne64(%#x): got %d, wanted %d", n, got, want) - } - } -} - -func TestForEachSetBit64(t *testing.T) { - for _, want := range [][]int{ - {}, - {0}, - {1}, - {63}, - {0, 1}, - {1, 3, 5}, - {0, 63}, - } { - n := Mask64(want...) - // "Slice values are deeply equal when ... they are both nil or both - // non-nil ..." - got := make([]int, 0) - ForEachSetBit64(n, func(i int) { - got = append(got, i) - }) - if !reflect.DeepEqual(got, want) { - t.Errorf("ForEachSetBit64(%#x): iterated bits %v, wanted %v", n, got, want) - } - } -} - -func TestIsOn(t *testing.T) { - type spec struct { - mask uint64 - bits uint64 - any bool - all bool - } - for _, s := range []spec{ - {Mask64(0), Mask64(0), true, true}, - {Mask64(63), Mask64(63), true, true}, - {Mask64(0), Mask64(1), false, false}, - {Mask64(0), Mask64(0, 1), true, false}, - - {Mask64(1, 63), Mask64(1), true, true}, - {Mask64(1, 63), Mask64(1, 63), true, true}, - {Mask64(1, 63), Mask64(0, 1, 63), true, false}, - {Mask64(1, 63), Mask64(0, 62), false, false}, - } { - if ok := IsAnyOn64(s.mask, s.bits); ok != s.any { - t.Errorf("IsAnyOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.any) - } - if ok := IsOn64(s.mask, s.bits); ok != s.all { - t.Errorf("IsOn(%#x, %#x) = %v, wanted: %v", s.mask, s.bits, ok, s.all) - } - } -} diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD deleted file mode 100644 index b692aa3b1..000000000 --- a/pkg/bpf/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "bpf", - srcs = [ - "bpf.go", - "decoder.go", - "input_bytes.go", - "interpreter.go", - "program_builder.go", - ], - importpath = "gvisor.dev/gvisor/pkg/bpf", - visibility = ["//visibility:public"], - deps = ["//pkg/abi/linux"], -) - -go_test( - name = "bpf_test", - size = "small", - srcs = [ - "decoder_test.go", - "interpreter_test.go", - "program_builder_test.go", - ], - embed = [":bpf"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - ], -) diff --git a/pkg/bpf/bpf_state_autogen.go b/pkg/bpf/bpf_state_autogen.go new file mode 100755 index 000000000..d10db8e49 --- /dev/null +++ b/pkg/bpf/bpf_state_autogen.go @@ -0,0 +1,22 @@ +// automatically generated by stateify. + +package bpf + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Program) beforeSave() {} +func (x *Program) save(m state.Map) { + x.beforeSave() + m.Save("instructions", &x.instructions) +} + +func (x *Program) afterLoad() {} +func (x *Program) load(m state.Map) { + m.Load("instructions", &x.instructions) +} + +func init() { + state.Register("bpf.Program", (*Program)(nil), state.Fns{Save: (*Program).save, Load: (*Program).load}) +} diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go deleted file mode 100644 index 6a023f0c0..000000000 --- a/pkg/bpf/decoder_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" -) - -func TestDecode(t *testing.T) { - for _, test := range []struct { - filter linux.BPFInstruction - expected string - fail bool - }{ - {filter: Stmt(Ld+Imm, 10), expected: "A <- 10"}, - {filter: Stmt(Ld+Abs+W, 10), expected: "A <- P[10:4]"}, - {filter: Stmt(Ld+Ind+H, 10), expected: "A <- P[X+10:2]"}, - {filter: Stmt(Ld+Ind+B, 10), expected: "A <- P[X+10:1]"}, - {filter: Stmt(Ld+Mem, 10), expected: "A <- M[10]"}, - {filter: Stmt(Ld+Len, 0), expected: "A <- len"}, - {filter: Stmt(Ldx+Imm, 10), expected: "X <- 10"}, - {filter: Stmt(Ldx+Mem, 10), expected: "X <- M[10]"}, - {filter: Stmt(Ldx+Len, 0), expected: "X <- len"}, - {filter: Stmt(Ldx+Msh, 10), expected: "X <- 4*(P[10:1]&0xf)"}, - {filter: Stmt(St, 10), expected: "M[10] <- A"}, - {filter: Stmt(Stx, 10), expected: "M[10] <- X"}, - {filter: Stmt(Alu+Add+K, 10), expected: "A <- A + 10"}, - {filter: Stmt(Alu+Sub+K, 10), expected: "A <- A - 10"}, - {filter: Stmt(Alu+Mul+K, 10), expected: "A <- A * 10"}, - {filter: Stmt(Alu+Div+K, 10), expected: "A <- A / 10"}, - {filter: Stmt(Alu+Or+K, 10), expected: "A <- A | 10"}, - {filter: Stmt(Alu+And+K, 10), expected: "A <- A & 10"}, - {filter: Stmt(Alu+Lsh+K, 10), expected: "A <- A << 10"}, - {filter: Stmt(Alu+Rsh+K, 10), expected: "A <- A >> 10"}, - {filter: Stmt(Alu+Mod+K, 10), expected: "A <- A % 10"}, - {filter: Stmt(Alu+Xor+K, 10), expected: "A <- A ^ 10"}, - {filter: Stmt(Alu+Add+X, 0), expected: "A <- A + X"}, - {filter: Stmt(Alu+Sub+X, 0), expected: "A <- A - X"}, - {filter: Stmt(Alu+Mul+X, 0), expected: "A <- A * X"}, - {filter: Stmt(Alu+Div+X, 0), expected: "A <- A / X"}, - {filter: Stmt(Alu+Or+X, 0), expected: "A <- A | X"}, - {filter: Stmt(Alu+And+X, 0), expected: "A <- A & X"}, - {filter: Stmt(Alu+Lsh+X, 0), expected: "A <- A << X"}, - {filter: Stmt(Alu+Rsh+X, 0), expected: "A <- A >> X"}, - {filter: Stmt(Alu+Mod+X, 0), expected: "A <- A % X"}, - {filter: Stmt(Alu+Xor+X, 0), expected: "A <- A ^ X"}, - {filter: Stmt(Alu+Neg, 0), expected: "A <- -A"}, - {filter: Stmt(Jmp+Ja, 10), expected: "pc += 10"}, - {filter: Jump(Jmp+Jeq+K, 10, 2, 5), expected: "pc += (A == 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jgt+K, 10, 2, 5), expected: "pc += (A > 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jge+K, 10, 2, 5), expected: "pc += (A >= 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jset+K, 10, 2, 5), expected: "pc += (A & 10) ? 2 : 5"}, - {filter: Jump(Jmp+Jeq+X, 0, 2, 5), expected: "pc += (A == X) ? 2 : 5"}, - {filter: Jump(Jmp+Jgt+X, 0, 2, 5), expected: "pc += (A > X) ? 2 : 5"}, - {filter: Jump(Jmp+Jge+X, 0, 2, 5), expected: "pc += (A >= X) ? 2 : 5"}, - {filter: Jump(Jmp+Jset+X, 0, 2, 5), expected: "pc += (A & X) ? 2 : 5"}, - {filter: Stmt(Ret+K, 10), expected: "ret 10"}, - {filter: Stmt(Ret+A, 0), expected: "ret A"}, - {filter: Stmt(Misc+Tax, 0), expected: "X <- A"}, - {filter: Stmt(Misc+Txa, 0), expected: "A <- X"}, - {filter: Stmt(Ld+Ind+Msh, 0), fail: true}, - } { - got, err := Decode(test.filter) - if test.fail { - if err == nil { - t.Errorf("Decode(%v) failed, expected: 'error', got: %q", test.filter, got) - continue - } - } else { - if err != nil { - t.Errorf("Decode(%v) failed for test %q, error: %q", test.filter, test.expected, err) - continue - } - if got != test.expected { - t.Errorf("Decode(%v) failed, expected: %q, got: %q", test.filter, test.expected, got) - continue - } - } - } -} - -func TestDecodeProgram(t *testing.T) { - for _, test := range []struct { - name string - program []linux.BPFInstruction - expected string - fail bool - }{ - {name: "basic with jump indexes", - program: []linux.BPFInstruction{ - Stmt(Ld+Abs+W, 10), - Stmt(Ldx+Mem, 10), - Stmt(St, 10), - Stmt(Stx, 10), - Stmt(Alu+Add+K, 10), - Stmt(Jmp+Ja, 10), - Jump(Jmp+Jeq+K, 10, 2, 5), - Jump(Jmp+Jset+X, 0, 0, 5), - Stmt(Misc+Tax, 0), - }, - expected: "0: A <- P[10:4]\n" + - "1: X <- M[10]\n" + - "2: M[10] <- A\n" + - "3: M[10] <- X\n" + - "4: A <- A + 10\n" + - "5: pc += 10 [16]\n" + - "6: pc += (A == 10) ? 2 [9] : 5 [12]\n" + - "7: pc += (A & X) ? 0 [8] : 5 [13]\n" + - "8: X <- A\n", - }, - {name: "invalid instruction", - program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)}, - fail: true}, - } { - got, err := DecodeProgram(test.program) - if test.fail { - if err == nil { - t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got) - continue - } - } else { - if err != nil { - t.Errorf("%s: Decode failed: %v", test.name, err) - continue - } - if got != test.expected { - t.Errorf("%s: Decode(...) failed, expected: %q, got: %q", test.name, test.expected, got) - continue - } - } - } -} diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go deleted file mode 100644 index 547921d0a..000000000 --- a/pkg/bpf/interpreter_test.go +++ /dev/null @@ -1,797 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" -) - -func TestCompilationErrors(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be compiled. - insns []linux.BPFInstruction - - // expectedErr is the expected compilation error. - expectedErr error - }{ - { - desc: "Instructions must not be nil", - expectedErr: Error{InvalidInstructionCount, 0}, - }, - { - desc: "Instructions must not be empty", - insns: []linux.BPFInstruction{}, - expectedErr: Error{InvalidInstructionCount, 0}, - }, - { - desc: "A program must end with a return", - insns: make([]linux.BPFInstruction, MaxInstructions), - expectedErr: Error{InvalidEndOfProgram, MaxInstructions - 1}, - }, - { - desc: "A program must have MaxInstructions or fewer instructions", - insns: append(make([]linux.BPFInstruction, MaxInstructions), Stmt(Ret|K, 0)), - expectedErr: Error{InvalidInstructionCount, MaxInstructions + 1}, - }, - { - desc: "A load from an invalid M register is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(Ld|Mem|W, ScratchMemRegisters), // A = M[16] - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidRegister, 0}, - }, - { - desc: "A store to an invalid M register is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(St, ScratchMemRegisters), // M[16] = A - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidRegister, 0}, - }, - { - desc: "Division by literal zero is a compilation error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Div|K, 0), // A /= 0 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - { - desc: "An unconditional jump outside of the program is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - { - desc: "A conditional jump outside of the program in the true case is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 0, 1, 0), // if (A == K) jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - { - desc: "A conditional jump outside of the program in the false case is a compilation error", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 0, 0, 1), // if (A != K) jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidJumpTarget, 0}, - }, - } { - _, err := Compile(test.insns) - if err != test.expectedErr { - t.Errorf("%s: expected error %q, got error %q", test.desc, test.expectedErr, err) - } - } -} - -func TestExecErrors(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be executed. - insns []linux.BPFInstruction - - // expectedErr is the expected execution error. - expectedErr error - }{ - { - desc: "An out-of-bounds load of input data is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|B, 0), // A = input[0] - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{InvalidLoad, 0}, - }, - { - desc: "Division by zero at runtime is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Div|X, 0), // A /= X - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - { - desc: "Modulo zero at runtime is an execution error", - insns: []linux.BPFInstruction{ - Stmt(Alu|Mod|X, 0), // A %= X - Stmt(Ret|K, 0), // return 0 - }, - expectedErr: Error{DivisionByZero, 0}, - }, - } { - p, err := Compile(test.insns) - if err != nil { - t.Errorf("%s: unexpected compilation error: %v", test.desc, err) - continue - } - ret, err := Exec(p, InputBytes{nil, binary.BigEndian}) - if err != test.expectedErr { - t.Errorf("%s: expected execution error %q, got (%d, %v)", test.desc, test.expectedErr, ret, err) - } - } -} - -func TestValidInstructions(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // insns is the BPF instructions to be compiled. - insns []linux.BPFInstruction - - // input is the input data. Note that input will be read as big-endian. - input []byte - - // expectedRet is the expected return value of the BPF program. - expectedRet uint32 - }{ - { - desc: "Return of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ret|K, 42), // return 42 - }, - expectedRet: 42, - }, - { - desc: "Load of immediate into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of immediate into X and copying of X into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 42), // X = 42 - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Copying of A into X and back", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Misc|Txa, 0), // X = A - Stmt(Ld|Imm|W, 0), // A = 0 - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of 32-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|W, 1), // A = input[1..4] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33, 0x44}, - expectedRet: 0x11223344, - }, - { - desc: "Load of 16-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|H, 1), // A = input[1..2] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22}, - expectedRet: 0x1122, - }, - { - desc: "Load of 8-bit input by absolute offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Abs|B, 1), // A = input[1] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11}, - expectedRet: 0x11, - }, - { - desc: "Load of 32-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|W, 1), // A = input[X+1..X+4] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, - expectedRet: 0x22334455, - }, - { - desc: "Load of 16-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|H, 1), // A = input[X+1..X+2] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22, 0x33}, - expectedRet: 0x2233, - }, - { - desc: "Load of 8-bit input by relative offset into A", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 1), // X = 1 - Stmt(Ld|Ind|B, 1), // A = input[X+1] - Stmt(Ret|A, 0), // return A - }, - input: []byte{0x00, 0x11, 0x22}, - expectedRet: 0x22, - }, - { - desc: "Load/store between A and scratch memory", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(St, 2), // M[2] = A - Stmt(Ld|Imm|W, 0), // A = 0 - Stmt(Ld|Mem|W, 2), // A = M[2] - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load/store between X and scratch memory", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Imm|W, 42), // X = 42 - Stmt(Stx, 3), // M[3] = X - Stmt(Ldx|Imm|W, 0), // X = 0 - Stmt(Ldx|Mem|W, 3), // X = M[3] - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 42, - }, - { - desc: "Load of input length into A", - insns: []linux.BPFInstruction{ - Stmt(Ld|Len|W, 0), // A = len(input) - Stmt(Ret|A, 0), // return A - }, - input: []byte{1, 2, 3}, - expectedRet: 3, - }, - { - desc: "Load of input length into X", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Len|W, 0), // X = len(input) - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - input: []byte{1, 2, 3}, - expectedRet: 3, - }, - { - desc: "Load of MSH (?) into X", - insns: []linux.BPFInstruction{ - Stmt(Ldx|Msh|B, 0), // X = 4*(input[0]&0xf) - Stmt(Misc|Tax, 0), // A = X - Stmt(Ret|A, 0), // return A - }, - input: []byte{0xf1}, - expectedRet: 4, - }, - { - desc: "Addition of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Alu|Add|K, 20), // A += 20 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 30, - }, - { - desc: "Addition of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 20), // X = 20 - Stmt(Alu|Add|X, 0), // A += X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 30, - }, - { - desc: "Subtraction of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 30), // A = 30 - Stmt(Alu|Sub|K, 20), // A -= 20 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 10, - }, - { - desc: "Subtraction of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 30), // A = 30 - Stmt(Ldx|Imm|W, 20), // X = 20 - Stmt(Alu|Sub|X, 0), // A -= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 10, - }, - { - desc: "Multiplication of immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 2), // A = 2 - Stmt(Alu|Mul|K, 3), // A *= 3 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 6, - }, - { - desc: "Multiplication of X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 2), // A = 2 - Stmt(Ldx|Imm|W, 3), // X = 3 - Stmt(Alu|Mul|X, 0), // A *= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 6, - }, - { - desc: "Division by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 6), // A = 6 - Stmt(Alu|Div|K, 3), // A /= 3 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 2, - }, - { - desc: "Division by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 6), // A = 6 - Stmt(Ldx|Imm|W, 3), // X = 3 - Stmt(Alu|Div|X, 0), // A /= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 2, - }, - { - desc: "Modulo immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 17), // A = 17 - Stmt(Alu|Mod|K, 7), // A %= 7 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 3, - }, - { - desc: "Modulo X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 17), // A = 17 - Stmt(Ldx|Imm|W, 7), // X = 7 - Stmt(Alu|Mod|X, 0), // A %= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 3, - }, - { - desc: "Arithmetic negation", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Alu|Neg, 0), // A = -A - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xffffffff, - }, - { - desc: "Bitwise OR with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|Or|K, 0xff0055aa), // A |= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff00ffff, - }, - { - desc: "Bitwise OR with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|Or|X, 0), // A |= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff00ffff, - }, - { - desc: "Bitwise AND with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|And|K, 0xff0055aa), // A &= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff000000, - }, - { - desc: "Bitwise AND with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|And|X, 0), // A &= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0xff000000, - }, - { - desc: "Bitwise XOR with immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Alu|Xor|K, 0xff0055aa), // A ^= 0xff0055aa - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0x0000ffff, - }, - { - desc: "Bitwise XOR with X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff00aa55), // A = 0xff00aa55 - Stmt(Ldx|Imm|W, 0xff0055aa), // X = 0xff0055aa - Stmt(Alu|Xor|X, 0), // A ^= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 0x0000ffff, - }, - { - desc: "Left shift by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Alu|Lsh|K, 5), // A <<= 5 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 32, - }, - { - desc: "Left shift by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 1), // A = 1 - Stmt(Ldx|Imm|W, 5), // X = 5 - Stmt(Alu|Lsh|X, 0), // A <<= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 32, - }, - { - desc: "Right shift by immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff - Stmt(Alu|Rsh|K, 31), // A >>= 31 - Stmt(Ret|A, 0), // return A - }, - expectedRet: 1, - }, - { - desc: "Right shift by X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xffffffff), // A = 0xffffffff - Stmt(Ldx|Imm|W, 31), // X = 31 - Stmt(Alu|Rsh|X, 0), // A >>= X - Stmt(Ret|A, 0), // return A - }, - expectedRet: 1, - }, - { - desc: "Unconditional jump", - insns: []linux.BPFInstruction{ - Jump(Jmp|Ja, 1, 0, 0), // jmp nextpc+1 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - }, - expectedRet: 1, - }, - { - desc: "Jump when A == immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A != immediate", - insns: []linux.BPFInstruction{ - Jump(Jmp|Jeq|K, 42, 1, 2), // if (A == 42) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A == X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Stmt(Ldx|Imm|W, 42), // X = 42 - Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A != X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 42), // A = 42 - Jump(Jmp|Jeq|X, 0, 1, 2), // if (A == X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A > immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jgt|K, 9, 1, 2), // if (A > 9) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A <= immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jgt|K, 10, 1, 2), // if (A > 10) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A > X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 9), // X = 9 - Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A <= X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 10), // X = 10 - Jump(Jmp|Jgt|X, 0, 1, 2), // if (A > X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A >= immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jge|K, 10, 1, 2), // if (A >= 10) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A < immediate", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Jump(Jmp|Jge|K, 11, 1, 2), // if (A >= 11) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A >= X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 10), // X = 10 - Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A < X", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 10), // A = 10 - Stmt(Ldx|Imm|W, 11), // X = 11 - Jump(Jmp|Jge|X, 0, 1, 2), // if (A >= X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A & immediate != 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff), // A = 0xff - Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A & immediate == 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xfe), // A = 0xfe - Jump(Jmp|Jset|K, 0x101, 1, 2), // if (A & 0x101) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - { - desc: "Jump when A & X != 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xff), // A = 0xff - Stmt(Ldx|Imm|W, 0x101), // X = 0x101 - Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 1, - }, - { - desc: "Jump when A & X == 0", - insns: []linux.BPFInstruction{ - Stmt(Ld|Imm|W, 0xfe), // A = 0xfe - Stmt(Ldx|Imm|W, 0x101), // X = 0x101 - Jump(Jmp|Jset|X, 0, 1, 2), // if (A & X) jmp nextpc+1 else jmp nextpc+2 - Stmt(Ret|K, 0), // return 0 - Stmt(Ret|K, 1), // return 1 - Stmt(Ret|K, 2), // return 2 - }, - expectedRet: 2, - }, - } { - p, err := Compile(test.insns) - if err != nil { - t.Errorf("%s: unexpected compilation error: %v", test.desc, err) - continue - } - ret, err := Exec(p, InputBytes{test.input, binary.BigEndian}) - if err != nil { - t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) - continue - } - if ret != test.expectedRet { - t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret) - } - } -} - -func TestSimpleFilter(t *testing.T) { - // Seccomp filter example given in Linux's - // Documentation/networking/filter.txt, translated to bytecode using the - // Linux kernel tree's tools/net/bpf_asm. - filter := []linux.BPFInstruction{ - {0x20, 0, 0, 0x00000004}, // ld [4] /* offsetof(struct seccomp_data, arch) */ - {0x15, 0, 11, 0xc000003e}, // jne #0xc000003e, bad /* AUDIT_ARCH_X86_64 */ - {0x20, 0, 0, 0000000000}, // ld [0] /* offsetof(struct seccomp_data, nr) */ - {0x15, 10, 0, 0x0000000f}, // jeq #15, good /* __NR_rt_sigreturn */ - {0x15, 9, 0, 0x000000e7}, // jeq #231, good /* __NR_exit_group */ - {0x15, 8, 0, 0x0000003c}, // jeq #60, good /* __NR_exit */ - {0x15, 7, 0, 0000000000}, // jeq #0, good /* __NR_read */ - {0x15, 6, 0, 0x00000001}, // jeq #1, good /* __NR_write */ - {0x15, 5, 0, 0x00000005}, // jeq #5, good /* __NR_fstat */ - {0x15, 4, 0, 0x00000009}, // jeq #9, good /* __NR_mmap */ - {0x15, 3, 0, 0x0000000e}, // jeq #14, good /* __NR_rt_sigprocmask */ - {0x15, 2, 0, 0x0000000d}, // jeq #13, good /* __NR_rt_sigaction */ - {0x15, 1, 0, 0x00000023}, // jeq #35, good /* __NR_nanosleep */ - {0x06, 0, 0, 0000000000}, // bad: ret #0 /* SECCOMP_RET_KILL */ - {0x06, 0, 0, 0x7fff0000}, // good: ret #0x7fff0000 /* SECCOMP_RET_ALLOW */ - } - p, err := Compile(filter) - if err != nil { - t.Fatalf("Unexpected compilation error: %v", err) - } - - for _, test := range []struct { - // desc is the test's description. - desc string - - // seccompData is the input data. - seccompData - - // expectedRet is the expected return value of the BPF program. - expectedRet uint32 - }{ - { - desc: "Invalid arch is rejected", - seccompData: seccompData{nr: 1 /* x86 exit */, arch: 0x40000003 /* AUDIT_ARCH_I386 */}, - expectedRet: 0, - }, - { - desc: "Disallowed syscall is rejected", - seccompData: seccompData{nr: 105 /* __NR_setuid */, arch: 0xc000003e}, - expectedRet: 0, - }, - { - desc: "Whitelisted syscall is allowed", - seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e}, - expectedRet: 0x7fff0000, - }, - } { - ret, err := Exec(p, test.seccompData.asInput()) - if err != nil { - t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) - continue - } - if ret != test.expectedRet { - t.Errorf("%s: expected return value of %d, got value %d", test.desc, test.expectedRet, ret) - } - } -} - -// seccompData is equivalent to struct seccomp_data. -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - -// asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() Input { - return InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} -} diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go deleted file mode 100644 index 92ca5f4c3..000000000 --- a/pkg/bpf/program_builder_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bpf - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" -) - -func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error { - instructions, err := p.Instructions() - if err != nil { - return fmt.Errorf("Instructions() failed: %v", err) - } - got, err := DecodeProgram(instructions) - if err != nil { - return fmt.Errorf("DecodeProgram('instructions') failed: %v", err) - } - expectedDecoded, err := DecodeProgram(expected) - if err != nil { - return fmt.Errorf("DecodeProgram('expected') failed: %v", err) - } - if got != expectedDecoded { - return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got) - } - return nil -} - -func TestProgramBuilderSimple(t *testing.T) { - p := NewProgramBuilder() - p.AddStmt(Ld+Abs+W, 10) - p.AddJump(Jmp+Ja, 10, 0, 0) - - expected := []linux.BPFInstruction{ - Stmt(Ld+Abs+W, 10), - Jump(Jmp+Ja, 10, 0, 0), - } - - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } -} - -func TestProgramBuilderLabels(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 11, "label_1", 0) - p.AddJumpFalseLabel(Jmp+Jeq+K, 12, 0, "label_2") - p.AddJumpLabels(Jmp+Jeq+K, 13, "label_3", "label_4") - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 1) - if err := p.AddLabel("label_3"); err != nil { - t.Errorf("AddLabel(label_3) failed: %v", err) - } - p.AddJumpLabels(Jmp+Jeq+K, 14, "label_4", "label_5") - if err := p.AddLabel("label_2"); err != nil { - t.Errorf("AddLabel(label_2) failed: %v", err) - } - p.AddJumpLabels(Jmp+Jeq+K, 15, "label_4", "label_6") - if err := p.AddLabel("label_4"); err != nil { - t.Errorf("AddLabel(label_4) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 4) - if err := p.AddLabel("label_5"); err != nil { - t.Errorf("AddLabel(label_5) failed: %v", err) - } - if err := p.AddLabel("label_6"); err != nil { - t.Errorf("AddLabel(label_6) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 5) - - expected := []linux.BPFInstruction{ - Jump(Jmp+Jeq+K, 11, 2, 0), - Jump(Jmp+Jeq+K, 12, 0, 3), - Jump(Jmp+Jeq+K, 13, 1, 3), - Stmt(Ld+Abs+W, 1), - Jump(Jmp+Jeq+K, 14, 1, 2), - Jump(Jmp+Jeq+K, 15, 0, 1), - Stmt(Ld+Abs+W, 4), - Stmt(Ld+Abs+W, 5), - } - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } - // Calling validate()=>p.Instructions() again to make sure - // Instructions can be called multiple times without ruining - // the program. - if err := validate(p, expected); err != nil { - t.Errorf("Validate() failed: %v", err) - } -} - -func TestProgramBuilderMissingErrorTarget(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} - -func TestProgramBuilderLabelWithNoInstruction(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} - -func TestProgramBuilderUnusedLabel(t *testing.T) { - p := NewProgramBuilder() - if err := p.AddLabel("unused"); err == nil { - t.Errorf("AddLabel(unused) should have failed") - } -} - -func TestProgramBuilderLabelAddedTwice(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 0) - if err := p.AddLabel("label_1"); err == nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } -} - -func TestProgramBuilderJumpBackwards(t *testing.T) { - p := NewProgramBuilder() - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if err := p.AddLabel("label_1"); err != nil { - t.Errorf("AddLabel(label_1) failed: %v", err) - } - p.AddStmt(Ld+Abs+W, 0) - p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "label_1", 0) - if _, err := p.Instructions(); err == nil { - t.Errorf("Instructions() should have failed") - } -} diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD deleted file mode 100644 index cdec96df1..000000000 --- a/pkg/compressio/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "compressio", - srcs = ["compressio.go"], - importpath = "gvisor.dev/gvisor/pkg/compressio", - visibility = ["//:sandbox"], - deps = ["//pkg/binary"], -) - -go_test( - name = "compressio_test", - size = "medium", - srcs = ["compressio_test.go"], - embed = [":compressio"], -) diff --git a/pkg/compressio/compressio_state_autogen.go b/pkg/compressio/compressio_state_autogen.go new file mode 100755 index 000000000..cac5ea41c --- /dev/null +++ b/pkg/compressio/compressio_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package compressio + diff --git a/pkg/compressio/compressio_test.go b/pkg/compressio/compressio_test.go deleted file mode 100644 index 86dc47e44..000000000 --- a/pkg/compressio/compressio_test.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compressio - -import ( - "bytes" - "compress/flate" - "encoding/base64" - "fmt" - "io" - "math/rand" - "runtime" - "testing" - "time" -) - -type harness interface { - Errorf(format string, v ...interface{}) - Fatalf(format string, v ...interface{}) - Logf(format string, v ...interface{}) -} - -func initTest(t harness, size int) []byte { - // Set number of processes to number of CPUs. - runtime.GOMAXPROCS(runtime.NumCPU()) - - // Construct synthetic data. We do this by encoding random data with - // base64. This gives a high level of entropy, but still quite a bit of - // structure, to give reasonable compression ratios (~75%). - var buf bytes.Buffer - bufW := base64.NewEncoder(base64.RawStdEncoding, &buf) - bufR := rand.New(rand.NewSource(0)) - if _, err := io.CopyN(bufW, bufR, int64(size)); err != nil { - t.Fatalf("unable to seed random data: %v", err) - } - return buf.Bytes() -} - -type testOpts struct { - Name string - Data []byte - NewWriter func(*bytes.Buffer) (io.Writer, error) - NewReader func(*bytes.Buffer) (io.Reader, error) - PreCompress func() - PostCompress func() - PreDecompress func() - PostDecompress func() - CompressIters int - DecompressIters int - CorruptData bool -} - -func doTest(t harness, opts testOpts) { - // Compress. - var compressed bytes.Buffer - compressionStartTime := time.Now() - if opts.PreCompress != nil { - opts.PreCompress() - } - if opts.CompressIters <= 0 { - opts.CompressIters = 1 - } - for i := 0; i < opts.CompressIters; i++ { - compressed.Reset() - w, err := opts.NewWriter(&compressed) - if err != nil { - t.Errorf("%s: NewWriter got err %v, expected nil", opts.Name, err) - } - if _, err := io.Copy(w, bytes.NewBuffer(opts.Data)); err != nil { - t.Errorf("%s: compress got err %v, expected nil", opts.Name, err) - return - } - closer, ok := w.(io.Closer) - if ok { - if err := closer.Close(); err != nil { - t.Errorf("%s: got err %v, expected nil", opts.Name, err) - return - } - } - } - if opts.PostCompress != nil { - opts.PostCompress() - } - compressionTime := time.Since(compressionStartTime) - compressionRatio := float32(compressed.Len()) / float32(len(opts.Data)) - - // Decompress. - var decompressed bytes.Buffer - decompressionStartTime := time.Now() - if opts.PreDecompress != nil { - opts.PreDecompress() - } - if opts.DecompressIters <= 0 { - opts.DecompressIters = 1 - } - if opts.CorruptData { - b := compressed.Bytes() - b[rand.Intn(len(b))]++ - } - for i := 0; i < opts.DecompressIters; i++ { - decompressed.Reset() - r, err := opts.NewReader(bytes.NewBuffer(compressed.Bytes())) - if err != nil { - if opts.CorruptData { - continue - } - t.Errorf("%s: NewReader got err %v, expected nil", opts.Name, err) - return - } - if _, err := io.Copy(&decompressed, r); (err != nil) != opts.CorruptData { - t.Errorf("%s: decompress got err %v unexpectly", opts.Name, err) - return - } - } - if opts.PostDecompress != nil { - opts.PostDecompress() - } - decompressionTime := time.Since(decompressionStartTime) - - if opts.CorruptData { - return - } - - // Verify. - if decompressed.Len() != len(opts.Data) { - t.Errorf("%s: got %d bytes, expected %d", opts.Name, decompressed.Len(), len(opts.Data)) - } - if !bytes.Equal(opts.Data, decompressed.Bytes()) { - t.Errorf("%s: got mismatch, expected match", opts.Name) - if len(opts.Data) < 32 { // Don't flood the logs. - t.Errorf("got %v, expected %v", decompressed.Bytes(), opts.Data) - } - } - - t.Logf("%s: compression time %v, ratio %2.2f, decompression time %v", - opts.Name, compressionTime, compressionRatio, decompressionTime) -} - -var hashKey = []byte("01234567890123456789012345678901") - -func TestCompress(t *testing.T) { - rand.Seed(time.Now().Unix()) - - var ( - data = initTest(t, 10*1024*1024) - data0 = data[:0] - data1 = data[:1] - data2 = data[:11] - data3 = data[:16] - data4 = data[:] - ) - - for _, data := range [][]byte{data0, data1, data2, data3, data4} { - for _, blockSize := range []uint32{1, 4, 1024, 4 * 1024, 16 * 1024} { - // Skip annoying tests; they just take too long. - if blockSize <= 16 && len(data) > 16 { - continue - } - - for _, key := range [][]byte{nil, hashKey} { - for _, corruptData := range []bool{false, true} { - if key == nil && corruptData { - // No need to test corrupt data - // case when not doing hashing. - continue - } - // Do the compress test. - doTest(t, testOpts{ - Name: fmt.Sprintf("len(data)=%d, blockSize=%d, key=%s, corruptData=%v", len(data), blockSize, string(key), corruptData), - Data: data, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, key, blockSize, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b, key) - }, - CorruptData: corruptData, - }) - } - } - } - - // Do the vanilla test. - doTest(t, testOpts{ - Name: fmt.Sprintf("len(data)=%d, vanilla flate", len(data)), - Data: data, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return flate.NewWriter(b, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return flate.NewReader(b), nil - }, - }) - } -} - -const ( - benchDataSize = 600 * 1024 * 1024 -) - -func benchmark(b *testing.B, compress bool, hash bool, blockSize uint32) { - b.StopTimer() - b.SetBytes(benchDataSize) - data := initTest(b, benchDataSize) - compIters := b.N - decompIters := b.N - if compress { - decompIters = 0 - } else { - compIters = 0 - } - key := hashKey - if !hash { - key = nil - } - doTest(b, testOpts{ - Name: fmt.Sprintf("compress=%t, hash=%t, len(data)=%d, blockSize=%d", compress, hash, len(data), blockSize), - Data: data, - PreCompress: b.StartTimer, - PostCompress: b.StopTimer, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, key, blockSize, flate.BestSpeed) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b, key) - }, - CompressIters: compIters, - DecompressIters: decompIters, - }) -} - -func BenchmarkCompressNoHash64K(b *testing.B) { - benchmark(b, true, false, 64*1024) -} - -func BenchmarkCompressHash64K(b *testing.B) { - benchmark(b, true, true, 64*1024) -} - -func BenchmarkDecompressNoHash64K(b *testing.B) { - benchmark(b, false, false, 64*1024) -} - -func BenchmarkDecompressHash64K(b *testing.B) { - benchmark(b, false, true, 64*1024) -} - -func BenchmarkCompressNoHash1M(b *testing.B) { - benchmark(b, true, false, 1024*1024) -} - -func BenchmarkCompressHash1M(b *testing.B) { - benchmark(b, true, true, 1024*1024) -} - -func BenchmarkDecompressNoHash1M(b *testing.B) { - benchmark(b, false, false, 1024*1024) -} - -func BenchmarkDecompressHash1M(b *testing.B) { - benchmark(b, false, true, 1024*1024) -} - -func BenchmarkCompressNoHash16M(b *testing.B) { - benchmark(b, true, false, 16*1024*1024) -} - -func BenchmarkCompressHash16M(b *testing.B) { - benchmark(b, true, true, 16*1024*1024) -} - -func BenchmarkDecompressNoHash16M(b *testing.B) { - benchmark(b, false, false, 16*1024*1024) -} - -func BenchmarkDecompressHash16M(b *testing.B) { - benchmark(b, false, true, 16*1024*1024) -} diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD deleted file mode 100644 index 066d7b1a1..000000000 --- a/pkg/control/client/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "client", - srcs = [ - "client.go", - ], - importpath = "gvisor.dev/gvisor/pkg/control/client", - visibility = ["//:sandbox"], - deps = [ - "//pkg/unet", - "//pkg/urpc", - ], -) diff --git a/pkg/control/client/client_state_autogen.go b/pkg/control/client/client_state_autogen.go new file mode 100755 index 000000000..69ea753a9 --- /dev/null +++ b/pkg/control/client/client_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package client + diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD deleted file mode 100644 index 21adf3adf..000000000 --- a/pkg/control/server/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "server", - srcs = ["server.go"], - importpath = "gvisor.dev/gvisor/pkg/control/server", - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/unet", - "//pkg/urpc", - ], -) diff --git a/pkg/control/server/server_state_autogen.go b/pkg/control/server/server_state_autogen.go new file mode 100755 index 000000000..f2b4725d3 --- /dev/null +++ b/pkg/control/server/server_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package server + diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD deleted file mode 100644 index 830e19e07..000000000 --- a/pkg/cpuid/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "cpuid", - srcs = [ - "cpu_amd64.s", - "cpuid.go", - ], - importpath = "gvisor.dev/gvisor/pkg/cpuid", - visibility = ["//:sandbox"], - deps = ["//pkg/log"], -) - -go_test( - name = "cpuid_test", - size = "small", - srcs = ["cpuid_test.go"], - embed = [":cpuid"], -) - -go_test( - name = "cpuid_parse_test", - size = "small", - srcs = [ - "cpuid_parse_test.go", - ], - embed = [":cpuid"], - tags = ["manual"], -) diff --git a/pkg/cpuid/cpuid_parse_test.go b/pkg/cpuid/cpuid_parse_test.go deleted file mode 100644 index dd9969db4..000000000 --- a/pkg/cpuid/cpuid_parse_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cpuid - -import ( - "fmt" - "io/ioutil" - "regexp" - "strconv" - "strings" - "syscall" - "testing" -) - -func kernelVersion() (int, int, error) { - var u syscall.Utsname - if err := syscall.Uname(&u); err != nil { - return 0, 0, err - } - - var r string - for _, b := range u.Release { - if b == 0 { - break - } - r += string(b) - } - - s := strings.Split(r, ".") - if len(s) < 2 { - return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", r) - } - - major, err := strconv.Atoi(s[0]) - if err != nil { - return 0, 0, fmt.Errorf("error parsing major version %q in %q: %v", s[0], r, err) - } - - minor, err := strconv.Atoi(s[1]) - if err != nil { - return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %v", s[1], r, err) - } - - return major, minor, nil -} - -// TestHostFeatureFlags tests that all features detected by HostFeatureSet are -// on the host. -// -// It does *not* verify that all features reported by the host are detected by -// HostFeatureSet. -// -// i.e., test that HostFeatureSet is a subset of the host features. -func TestHostFeatureFlags(t *testing.T) { - cpuinfoBytes, _ := ioutil.ReadFile("/proc/cpuinfo") - cpuinfo := string(cpuinfoBytes) - t.Logf("Host cpu info:\n%s", cpuinfo) - - major, minor, err := kernelVersion() - if err != nil { - t.Fatalf("Unable to parse kernel version: %v", err) - } - - re := regexp.MustCompile(`(?m)^flags\s+: (.*)$`) - m := re.FindStringSubmatch(cpuinfo) - if len(m) != 2 { - t.Fatalf("Unable to extract flags from %q", cpuinfo) - } - - cpuinfoFlags := make(map[string]struct{}) - for _, f := range strings.Split(m[1], " ") { - cpuinfoFlags[f] = struct{}{} - } - - fs := HostFeatureSet() - - // All features have a string and appear in host cpuinfo. - for f := range fs.Set { - name := f.flagString(false) - if name == "" { - t.Errorf("Non-parsable feature: %v", f) - } - - // Special cases not consistently visible. We don't mind if - // they are exposed in earlier versions. - switch { - // Block 0. - case f == X86FeatureSDBG && (major < 4 || major == 4 && minor < 3): - // SDBG only exposed in - // b1c599b8ff80ea79b9f8277a3f9f36a7b0cfedce (4.3). - continue - // Block 2. - case f == X86FeatureRDT && (major < 4 || major == 4 && minor < 10): - // RDT only exposed in - // 4ab1586488cb56ed8728e54c4157cc38646874d9 (4.10). - continue - // Block 3. - case f == X86FeatureAVX512VBMI && (major < 4 || major == 4 && minor < 10): - // AVX512VBMI only exposed in - // a8d9df5a509a232a959e4ef2e281f7ecd77810d6 (4.10). - continue - case f == X86FeatureUMIP && (major < 4 || major == 4 && minor < 15): - // UMIP only exposed in - // 3522c2a6a4f341058b8291326a945e2a2d2aaf55 (4.15). - continue - case f == X86FeaturePKU && (major < 4 || major == 4 && minor < 9): - // PKU only exposed in - // dfb4a70f20c5b3880da56ee4c9484bdb4e8f1e65 (4.9). - continue - // Block 4. - case f == X86FeatureXSAVES && (major < 4 || major == 4 && minor < 8): - // XSAVES only exposed in - // b8be15d588060a03569ac85dc4a0247460988f5b (4.8). - continue - // Block 5. - case f == X86FeaturePERFCTR_LLC && (major < 4 || major == 4 && minor < 14): - // PERFCTR_LLC renamed in - // 910448bbed066ab1082b510eef1ae61bb792d854 (4.14). - continue - } - - hidden := f.flagString(true) == "" - _, ok := cpuinfoFlags[name] - if hidden && ok { - t.Errorf("Unexpectedly hidden flag: %v", f) - } else if !hidden && !ok { - t.Errorf("Non-native flag: %v", f) - } - } -} diff --git a/pkg/cpuid/cpuid_state_autogen.go b/pkg/cpuid/cpuid_state_autogen.go new file mode 100755 index 000000000..9f487b0c4 --- /dev/null +++ b/pkg/cpuid/cpuid_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package cpuid + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FeatureSet) beforeSave() {} +func (x *FeatureSet) save(m state.Map) { + x.beforeSave() + m.Save("Set", &x.Set) + m.Save("VendorID", &x.VendorID) + m.Save("ExtendedFamily", &x.ExtendedFamily) + m.Save("ExtendedModel", &x.ExtendedModel) + m.Save("ProcessorType", &x.ProcessorType) + m.Save("Family", &x.Family) + m.Save("Model", &x.Model) + m.Save("SteppingID", &x.SteppingID) +} + +func (x *FeatureSet) afterLoad() {} +func (x *FeatureSet) load(m state.Map) { + m.Load("Set", &x.Set) + m.Load("VendorID", &x.VendorID) + m.Load("ExtendedFamily", &x.ExtendedFamily) + m.Load("ExtendedModel", &x.ExtendedModel) + m.Load("ProcessorType", &x.ProcessorType) + m.Load("Family", &x.Family) + m.Load("Model", &x.Model) + m.Load("SteppingID", &x.SteppingID) +} + +func init() { + state.Register("cpuid.FeatureSet", (*FeatureSet)(nil), state.Fns{Save: (*FeatureSet).save, Load: (*FeatureSet).load}) +} diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_test.go deleted file mode 100644 index 6ae14d2da..000000000 --- a/pkg/cpuid/cpuid_test.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cpuid - -import ( - "testing" -) - -// These are the default values of various FeatureSet fields. -const ( - defaultVendorID = "GenuineIntel" - - // These processor signature defaults are derived from the values - // listed in Intel Application Note 485 for i7/Xeon processors. - defaultExtFamily uint8 = 0 - defaultExtModel uint8 = 1 - defaultType uint8 = 0 - defaultFamily uint8 = 0x06 - defaultModel uint8 = 0x0a - defaultSteppingID uint8 = 0 -) - -// newEmptyFeatureSet creates a new FeatureSet with a sensible default model and no features. -func newEmptyFeatureSet() *FeatureSet { - return &FeatureSet{ - Set: make(map[Feature]bool), - VendorID: defaultVendorID, - ExtendedFamily: defaultExtFamily, - ExtendedModel: defaultExtModel, - ProcessorType: defaultType, - Family: defaultFamily, - Model: defaultModel, - SteppingID: defaultSteppingID, - } -} - -var justFPU = &FeatureSet{ - Set: map[Feature]bool{ - X86FeatureFPU: true, - }} - -var justFPUandPAE = &FeatureSet{ - Set: map[Feature]bool{ - X86FeatureFPU: true, - X86FeaturePAE: true, - }} - -func TestIsSubset(t *testing.T) { - if !justFPU.IsSubset(justFPUandPAE) { - t.Errorf("Got %v is not subset of %v, want IsSubset being true", justFPU, justFPUandPAE) - } - - if justFPUandPAE.IsSubset(justFPU) { - t.Errorf("Got %v is a subset of %v, want IsSubset being false", justFPU, justFPUandPAE) - } -} - -func TestTakeFeatureIntersection(t *testing.T) { - testFeatures := HostFeatureSet() - testFeatures.TakeFeatureIntersection(justFPU) - if !testFeatures.IsSubset(justFPU) { - t.Errorf("Got more features than expected after intersecting host features with justFPU: %v, want %v", testFeatures.Set, justFPU.Set) - } - if !testFeatures.HasFeature(X86FeatureFPU) { - t.Errorf("Got no features in testFeatures after intersecting, want %v", X86FeatureFPU) - } -} - -// TODO(b/73346484): Run this test on a very old platform, and make sure more -// bits are enabled than just FPU and PAE. This test currently may not detect -// if HostFeatureSet gives back junk bits. -func TestHostFeatureSet(t *testing.T) { - hostFeatures := HostFeatureSet() - if !justFPUandPAE.IsSubset(hostFeatures) { - t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures) - } -} - -func TestHasFeature(t *testing.T) { - if !justFPU.HasFeature(X86FeatureFPU) { - t.Errorf("HasFeature failed, %v should contain %v", justFPU, X86FeatureFPU) - } - - if justFPU.HasFeature(X86FeatureAVX) { - t.Errorf("HasFeature failed, %v should not contain %v", justFPU, X86FeatureAVX) - } -} - -// Note: these tests are aware of and abuse internal details of FeatureSets. -// Users of FeatureSets should not depend on this. -func TestAdd(t *testing.T) { - // Test a basic insertion into the FeatureSet. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureCLFSH) - if len(testFeatures.Set) != 1 { - t.Errorf("Got length %v want 1", len(testFeatures.Set)) - } - - if !testFeatures.HasFeature(X86FeatureCLFSH) { - t.Errorf("Add failed, got %v want set with %v", testFeatures, X86FeatureCLFSH) - } - - // Test that duplicates are ignored. - testFeatures.Add(X86FeatureCLFSH) - if len(testFeatures.Set) != 1 { - t.Errorf("Got length %v, want 1", len(testFeatures.Set)) - } -} - -func TestRemove(t *testing.T) { - // Try removing the last feature. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureFPU) - testFeatures.Add(X86FeaturePAE) - testFeatures.Remove(X86FeaturePAE) - if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 || testFeatures.HasFeature(X86FeaturePAE) { - t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU) - } - - // Try removing a feature not in the set. - testFeatures.Remove(X86FeatureRDRAND) - if !testFeatures.HasFeature(X86FeatureFPU) || len(testFeatures.Set) != 1 { - t.Errorf("Remove failed, got %v want %v", testFeatures, justFPU) - } -} - -func TestFeatureFromString(t *testing.T) { - f, ok := FeatureFromString("avx") - if f != X86FeatureAVX || !ok { - t.Errorf("got %v want avx", f) - } - - f, ok = FeatureFromString("bad") - if ok { - t.Errorf("got %v want nothing", f) - } -} - -// This tests function 0 (eax=0), which returns the vendor ID and highest cpuid -// function reported to be available. -func TestEmulateIDVendorAndLength(t *testing.T) { - testFeatures := newEmptyFeatureSet() - - ax, bx, cx, dx := testFeatures.EmulateID(0, 0) - wantEax := uint32(0xd) // Highest supported cpuid function. - - // These magical constants are the characters of "GenuineIntel". - // See Intel AN485 for a reference on why they are laid out like this. - wantEbx := uint32(0x756e6547) - wantEcx := uint32(0x6c65746e) - wantEdx := uint32(0x49656e69) - if wantEax != ax { - t.Errorf("highest function failed, got %x want %x", ax, wantEax) - } - - if wantEbx != bx || wantEcx != cx || wantEdx != dx { - t.Errorf("vendor string emulation failed, bx:cx:dx, got %x:%x:%x want %x:%x:%x", bx, cx, dx, wantEbx, wantEcx, wantEdx) - } -} - -func TestEmulateIDBasicFeatures(t *testing.T) { - // Make a minimal test feature set. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureCLFSH) - testFeatures.Add(X86FeatureAVX) - - ax, bx, cx, dx := testFeatures.EmulateID(1, 0) - ECXAVXBit := uint32(1 << uint(X86FeatureAVX)) - EDXCLFlushBit := uint32(1 << uint(X86FeatureCLFSH-32)) // We adjust by 32 since it's in block 1. - - if EDXCLFlushBit&dx == 0 || dx&^EDXCLFlushBit != 0 { - t.Errorf("EmulateID failed, got feature bits %x want %x", dx, testFeatures.blockMask(1)) - } - - if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 { - t.Errorf("EmulateID failed, got feature bits %x want %x", cx, testFeatures.blockMask(0)) - } - - // Default signature bits, based on values for i7/Xeon. - // See Intel AN485 for information on stepping/model bits. - defaultSignature := uint32(0x000106a0) - if defaultSignature != ax { - t.Errorf("EmulateID stepping emulation failed, got %x want %x", ax, defaultSignature) - } - - clflushSizeInfo := uint32(8 << 8) - if clflushSizeInfo != bx { - t.Errorf("EmulateID bx emulation failed, got %x want %x", bx, clflushSizeInfo) - } -} - -func TestEmulateIDExtendedFeatures(t *testing.T) { - // Make a minimal test feature set, one bit in each extended feature word. - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureSMEP) - testFeatures.Add(X86FeatureAVX512VBMI) - - ax, bx, cx, dx := testFeatures.EmulateID(7, 0) - EBXSMEPBit := uint32(1 << uint(X86FeatureSMEP-2*32)) // Adjust by 2*32 since SMEP is a block 2 feature. - ECXAVXBit := uint32(1 << uint(X86FeatureAVX512VBMI-3*32)) // We adjust by 3*32 since it's a block 3 feature. - - // Test that the desired bit is set and no other bits are set. - if EBXSMEPBit&bx == 0 || bx&^EBXSMEPBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", bx, testFeatures.blockMask(2)) - } - - if ECXAVXBit&cx == 0 || cx&^ECXAVXBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", cx, testFeatures.blockMask(3)) - } - - if ax != 0 || dx != 0 { - t.Errorf("extended feature emulation failed, ax:dx, got %x:%x want 0:0", ax, dx) - } - - // Check that no subleaves other than 0 do anything. - ax, bx, cx, dx = testFeatures.EmulateID(7, 1) - if ax != 0 || bx != 0 || cx != 0 || dx != 0 { - t.Errorf("extended feature emulation failed, got %x:%x:%x:%x want 0:0", ax, bx, cx, dx) - } - -} - -// Checks that the expected extended features are available via cpuid functions -// 0x80000000 and up. -func TestEmulateIDExtended(t *testing.T) { - testFeatures := newEmptyFeatureSet() - testFeatures.Add(X86FeatureSYSCALL) - EDXSYSCALLBit := uint32(1 << uint(X86FeatureSYSCALL-6*32)) // Adjust by 6*32 since SYSCALL is a block 6 feature. - - ax, bx, cx, dx := testFeatures.EmulateID(0x80000000, 0) - if ax != 0x80000001 || bx != 0 || cx != 0 || dx != 0 { - t.Errorf("EmulateID extended emulation failed, ax:bx:cx:dx, got %x:%x:%x:%x want 0x80000001:0:0:0", ax, bx, cx, dx) - } - - _, _, _, dx = testFeatures.EmulateID(0x80000001, 0) - if EDXSYSCALLBit&dx == 0 || dx&^EDXSYSCALLBit != 0 { - t.Errorf("extended feature emulation failed, got feature bits %x want %x", dx, testFeatures.blockMask(6)) - } -} diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD deleted file mode 100644 index 4c336ea84..000000000 --- a/pkg/eventchannel/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "eventchannel", - srcs = [ - "event.go", - ], - importpath = "gvisor.dev/gvisor/pkg/eventchannel", - visibility = ["//:sandbox"], - deps = [ - ":eventchannel_go_proto", - "//pkg/log", - "//pkg/unet", - "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_golang_protobuf//ptypes:go_default_library_gen", - ], -) - -proto_library( - name = "eventchannel_proto", - srcs = ["event.proto"], -) - -go_proto_library( - name = "eventchannel_go_proto", - importpath = "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto", - proto = ":eventchannel_proto", - visibility = ["//:sandbox"], -) diff --git a/pkg/eventchannel/event.proto b/pkg/eventchannel/event.proto deleted file mode 100644 index 34468f072..000000000 --- a/pkg/eventchannel/event.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// A debug event encapsulates any other event protobuf in text format. This is -// useful because clients reading events emitted this way do not need to link -// the event protobufs to display them in a human-readable format. -message DebugEvent { - // Name of the inner message. - string name = 1; - // Text representation of the inner message content. - string text = 2; -} diff --git a/pkg/eventchannel/eventchannel_go_proto/event.pb.go b/pkg/eventchannel/eventchannel_go_proto/event.pb.go new file mode 100755 index 000000000..bb71ed3e6 --- /dev/null +++ b/pkg/eventchannel/eventchannel_go_proto/event.pb.go @@ -0,0 +1,85 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/eventchannel/event.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type DebugEvent struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Text string `protobuf:"bytes,2,opt,name=text,proto3" json:"text,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *DebugEvent) Reset() { *m = DebugEvent{} } +func (m *DebugEvent) String() string { return proto.CompactTextString(m) } +func (*DebugEvent) ProtoMessage() {} +func (*DebugEvent) Descriptor() ([]byte, []int) { + return fileDescriptor_fcfbd51abd9de962, []int{0} +} + +func (m *DebugEvent) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_DebugEvent.Unmarshal(m, b) +} +func (m *DebugEvent) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_DebugEvent.Marshal(b, m, deterministic) +} +func (m *DebugEvent) XXX_Merge(src proto.Message) { + xxx_messageInfo_DebugEvent.Merge(m, src) +} +func (m *DebugEvent) XXX_Size() int { + return xxx_messageInfo_DebugEvent.Size(m) +} +func (m *DebugEvent) XXX_DiscardUnknown() { + xxx_messageInfo_DebugEvent.DiscardUnknown(m) +} + +var xxx_messageInfo_DebugEvent proto.InternalMessageInfo + +func (m *DebugEvent) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *DebugEvent) GetText() string { + if m != nil { + return m.Text + } + return "" +} + +func init() { + proto.RegisterType((*DebugEvent)(nil), "gvisor.DebugEvent") +} + +func init() { proto.RegisterFile("pkg/eventchannel/event.proto", fileDescriptor_fcfbd51abd9de962) } + +var fileDescriptor_fcfbd51abd9de962 = []byte{ + // 103 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x29, 0xc8, 0x4e, 0xd7, + 0x4f, 0x2d, 0x4b, 0xcd, 0x2b, 0x49, 0xce, 0x48, 0xcc, 0xcb, 0x4b, 0xcd, 0x81, 0x70, 0xf4, 0x0a, + 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0xd8, 0xd2, 0xcb, 0x32, 0x8b, 0xf3, 0x8b, 0x94, 0x4c, 0xb8, 0xb8, + 0x5c, 0x52, 0x93, 0x4a, 0xd3, 0x5d, 0x41, 0x72, 0x42, 0x42, 0x5c, 0x2c, 0x79, 0x89, 0xb9, 0xa9, + 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x60, 0x36, 0x48, 0xac, 0x24, 0xb5, 0xa2, 0x44, 0x82, + 0x09, 0x22, 0x06, 0x62, 0x27, 0xb1, 0x81, 0x0d, 0x31, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x17, + 0xee, 0x7f, 0xef, 0x64, 0x00, 0x00, 0x00, +} diff --git a/pkg/eventchannel/eventchannel_state_autogen.go b/pkg/eventchannel/eventchannel_state_autogen.go new file mode 100755 index 000000000..cfd3a5e43 --- /dev/null +++ b/pkg/eventchannel/eventchannel_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package eventchannel + diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD deleted file mode 100644 index 785c685a0..000000000 --- a/pkg/fd/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fd", - srcs = ["fd.go"], - importpath = "gvisor.dev/gvisor/pkg/fd", - visibility = ["//visibility:public"], -) - -go_test( - name = "fd_test", - size = "small", - srcs = ["fd_test.go"], - embed = [":fd"], -) diff --git a/pkg/fd/fd_state_autogen.go b/pkg/fd/fd_state_autogen.go new file mode 100755 index 000000000..0320140b0 --- /dev/null +++ b/pkg/fd/fd_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package fd + diff --git a/pkg/fd/fd_test.go b/pkg/fd/fd_test.go deleted file mode 100644 index 5fb0ad47d..000000000 --- a/pkg/fd/fd_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fd - -import ( - "math" - "os" - "syscall" - "testing" -) - -func TestSetNegOne(t *testing.T) { - type entry struct { - name string - file *FD - fn func() error - } - var tests []entry - - fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("syscall.Socket:", err) - } - f1 := New(fd) - tests = append(tests, entry{ - "Release", - f1, - func() error { - return syscall.Close(f1.Release()) - }, - }) - - fd, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("syscall.Socket:", err) - } - f2 := New(fd) - tests = append(tests, entry{ - "Close", - f2, - f2.Close, - }) - - for _, test := range tests { - if err := test.fn(); err != nil { - t.Errorf("%s: %v", test.name, err) - continue - } - if fd := test.file.FD(); fd != -1 { - t.Errorf("%s: got FD() = %d, want = -1", test.name, fd) - } - } -} - -func TestStartsNegOne(t *testing.T) { - type entry struct { - name string - file *FD - } - - tests := []entry{ - {"-1", New(-1)}, - {"-2", New(-2)}, - {"MinInt32", New(math.MinInt32)}, - {"MinInt64", New(math.MinInt64)}, - } - - for _, test := range tests { - if fd := test.file.FD(); fd != -1 { - t.Errorf("%s: got FD() = %d, want = -1", test.name, fd) - } - } -} - -func TestFileDotFile(t *testing.T) { - fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("syscall.Socket:", err) - } - - f := New(fd) - of, err := f.File() - if err != nil { - t.Fatalf("File got err %v want nil", err) - } - - if ofd, nfd := int(of.Fd()), f.FD(); ofd == nfd || ofd == -1 { - // Try not to double close the FD. - f.Release() - - t.Fatalf("got %#v.File().Fd() = %d, want new FD", f, ofd) - } - - f.Close() - of.Close() -} - -func TestFileDotFileError(t *testing.T) { - f := &FD{ReadWriter{-2}} - - if of, err := f.File(); err == nil { - t.Errorf("File %v got nil err want non-nil", of) - of.Close() - } -} - -func TestNewFromFile(t *testing.T) { - f, err := NewFromFile(os.Stdin) - if err != nil { - t.Fatalf("NewFromFile got err %v want nil", err) - } - if nfd, ofd := f.FD(), int(os.Stdin.Fd()); nfd == -1 || nfd == ofd { - t.Errorf("got FD() = %d, want = new FD (old FD was %d)", nfd, ofd) - } - f.Close() -} - -func TestNewFromFileError(t *testing.T) { - f, err := NewFromFile(nil) - if err == nil { - t.Errorf("NewFromFile got %v with nil err want non-nil", f) - f.Close() - } -} diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD deleted file mode 100644 index d0552c06e..000000000 --- a/pkg/fdnotifier/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "fdnotifier", - srcs = [ - "fdnotifier.go", - "poll_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/fdnotifier", - visibility = ["//:sandbox"], - deps = ["//pkg/waiter"], -) diff --git a/pkg/fdnotifier/fdnotifier_state_autogen.go b/pkg/fdnotifier/fdnotifier_state_autogen.go new file mode 100755 index 000000000..6f6076b7b --- /dev/null +++ b/pkg/fdnotifier/fdnotifier_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package fdnotifier + diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD deleted file mode 100644 index e6a8dbd02..000000000 --- a/pkg/gate/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gate", - srcs = [ - "gate.go", - ], - importpath = "gvisor.dev/gvisor/pkg/gate", - visibility = ["//visibility:public"], -) - -go_test( - name = "gate_test", - srcs = [ - "gate_test.go", - ], - deps = [ - ":gate", - ], -) diff --git a/pkg/gate/gate_state_autogen.go b/pkg/gate/gate_state_autogen.go new file mode 100755 index 000000000..a81fca776 --- /dev/null +++ b/pkg/gate/gate_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package gate + diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go deleted file mode 100644 index 5dbd8d712..000000000 --- a/pkg/gate/gate_test.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gate_test - -import ( - "sync" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/gate" -) - -func TestBasicEnter(t *testing.T) { - var g gate.Gate - - if !g.Enter() { - t.Fatalf("Failed to enter when it should be allowed") - } - - g.Leave() - - g.Close() - - if g.Enter() { - t.Fatalf("Allowed to enter when it should fail") - } -} - -func enterFunc(t *testing.T, g *gate.Gate, enter, leave, reenter chan struct{}, done1, done2, done3 *sync.WaitGroup) { - // Wait until instructed to enter. - <-enter - if !g.Enter() { - t.Errorf("Failed to enter when it should be allowed") - } - - done1.Done() - - // Wait until instructed to leave. - <-leave - g.Leave() - - done2.Done() - - // Wait until instructed to reenter. - <-reenter - if g.Enter() { - t.Errorf("Allowed to enter when it should fail") - } - done3.Done() -} - -func TestConcurrentEnter(t *testing.T) { - var g gate.Gate - var done1, done2, done3 sync.WaitGroup - - // Create 1000 worker goroutines. - enter := make(chan struct{}) - leave := make(chan struct{}) - reenter := make(chan struct{}) - done1.Add(1000) - done2.Add(1000) - done3.Add(1000) - for i := 0; i < 1000; i++ { - go enterFunc(t, &g, enter, leave, reenter, &done1, &done2, &done3) - } - - // Tell them all to enter, then leave. - close(enter) - done1.Wait() - - close(leave) - done2.Wait() - - // Close the gate, then have the workers try to enter again. - g.Close() - close(reenter) - done3.Wait() -} - -func closeFunc(g *gate.Gate, done chan struct{}) { - g.Close() - close(done) -} - -func TestCloseWaits(t *testing.T) { - var g gate.Gate - - // Enter 10 times. - for i := 0; i < 10; i++ { - if !g.Enter() { - t.Fatalf("Failed to enter when it should be allowed") - } - } - - // Launch closer. Check that it doesn't complete. - done := make(chan struct{}) - go closeFunc(&g, done) - - for i := 0; i < 10; i++ { - select { - case <-done: - t.Fatalf("Close function completed too soon") - case <-time.After(100 * time.Millisecond): - } - - g.Leave() - } - - // Now the closer must complete. - <-done -} - -func TestMultipleSerialCloses(t *testing.T) { - var g gate.Gate - - // Enter 10 times. - for i := 0; i < 10; i++ { - if !g.Enter() { - t.Fatalf("Failed to enter when it should be allowed") - } - } - - // Launch closer. Check that it doesn't complete. - done := make(chan struct{}) - go closeFunc(&g, done) - - for i := 0; i < 10; i++ { - select { - case <-done: - t.Fatalf("Close function completed too soon") - case <-time.After(100 * time.Millisecond): - } - - g.Leave() - } - - // Now the closer must complete. - <-done - - // Close again should not block. - done = make(chan struct{}) - go closeFunc(&g, done) - - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatalf("Second Close is blocking") - } -} - -func worker(g *gate.Gate, done *sync.WaitGroup) { - for { - if !g.Enter() { - break - } - g.Leave() - } - done.Done() -} - -func TestConcurrentAll(t *testing.T) { - var g gate.Gate - var done sync.WaitGroup - - // Launch 1000 goroutines to concurrently enter/leave. - done.Add(1000) - for i := 0; i < 1000; i++ { - go worker(&g, &done) - } - - // Wait for the goroutines to do some work, then close the gate. - time.Sleep(2 * time.Second) - g.Close() - - // Wait for all of them to complete. - done.Wait() -} diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD deleted file mode 100644 index 8f3defa25..000000000 --- a/pkg/ilist/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ilist", - srcs = [ - "interface_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/ilist", - visibility = ["//visibility:public"], -) - -go_template_instance( - name = "interface_list", - out = "interface_list.go", - package = "ilist", - template = ":generic_list", - types = {}, -) - -# This list is used for benchmarking. -go_template_instance( - name = "test_list", - out = "test_list.go", - package = "ilist", - prefix = "direct", - template = ":generic_list", - types = { - "Element": "*direct", - "Linker": "*direct", - }, -) - -go_test( - name = "list_test", - size = "small", - srcs = [ - "list_test.go", - "test_list.go", - ], - embed = [":ilist"], -) - -go_template( - name = "generic_list", - srcs = [ - "list.go", - ], - opt_types = [ - "Element", - "ElementMapper", - "Linker", - ], - visibility = ["//visibility:public"], -) diff --git a/pkg/ilist/ilist_state_autogen.go b/pkg/ilist/ilist_state_autogen.go new file mode 100755 index 000000000..03d7a9564 --- /dev/null +++ b/pkg/ilist/ilist_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package ilist + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *List) beforeSave() {} +func (x *List) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *List) afterLoad() {} +func (x *List) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *Entry) beforeSave() {} +func (x *Entry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *Entry) afterLoad() {} +func (x *Entry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("ilist.List", (*List)(nil), state.Fns{Save: (*List).save, Load: (*List).load}) + state.Register("ilist.Entry", (*Entry)(nil), state.Fns{Save: (*Entry).save, Load: (*Entry).load}) +} diff --git a/pkg/ilist/list.go b/pkg/ilist/interface_list.go index 019caadca..940c2d3f6 100644..100755 --- a/pkg/ilist/list.go +++ b/pkg/ilist/interface_list.go @@ -1,18 +1,3 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package ilist provides the implementation of intrusive linked lists. package ilist // Linker is the interface that objects must implement if they want to be added diff --git a/pkg/ilist/list_test.go b/pkg/ilist/list_test.go deleted file mode 100644 index 3f9abfb56..000000000 --- a/pkg/ilist/list_test.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ilist - -import ( - "testing" -) - -type testEntry struct { - Entry - value int -} - -type direct struct { - directEntry - value int -} - -func verifyEquality(t *testing.T, entries []testEntry, l *List) { - t.Helper() - - i := 0 - for it := l.Front(); it != nil; it = it.Next() { - e := it.(*testEntry) - if e != &entries[i] { - t.Errorf("Wrong entry at index %d", i) - return - } - i++ - } - - if i != len(entries) { - t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i) - return - } - - i = 0 - for it := l.Back(); it != nil; it = it.Prev() { - e := it.(*testEntry) - if e != &entries[len(entries)-1-i] { - t.Errorf("Wrong entry at index %d", i) - return - } - i++ - } - - if i != len(entries) { - t.Errorf("Wrong number of entries; want = %d, got = %d", len(entries), i) - return - } -} - -func TestZeroEmpty(t *testing.T) { - var l List - if l.Front() != nil { - t.Error("Front is non-nil") - } - if l.Back() != nil { - t.Error("Back is non-nil") - } -} - -func TestPushBack(t *testing.T) { - var l List - - // Test single entry insertion. - var entry testEntry - l.PushBack(&entry) - - e := l.Front().(*testEntry) - if e != &entry { - t.Error("Wrong entry returned") - } - - // Test inserting 100 entries. - l.Reset() - var entries [100]testEntry - for i := range entries { - l.PushBack(&entries[i]) - } - - verifyEquality(t, entries[:], &l) -} - -func TestPushFront(t *testing.T) { - var l List - - // Test single entry insertion. - var entry testEntry - l.PushFront(&entry) - - e := l.Front().(*testEntry) - if e != &entry { - t.Error("Wrong entry returned") - } - - // Test inserting 100 entries. - l.Reset() - var entries [100]testEntry - for i := range entries { - l.PushFront(&entries[len(entries)-1-i]) - } - - verifyEquality(t, entries[:], &l) -} - -func TestRemove(t *testing.T) { - // Remove entry from single-element list. - var l List - var entry testEntry - l.PushBack(&entry) - l.Remove(&entry) - if l.Front() != nil { - t.Error("List is empty") - } - - var entries [100]testEntry - - // Remove single element from lists of lengths 2 to 101. - for n := 1; n <= len(entries); n++ { - for extra := 0; extra <= n; extra++ { - l.Reset() - for i := 0; i < n; i++ { - if extra == i { - l.PushBack(&entry) - } - l.PushBack(&entries[i]) - } - if extra == n { - l.PushBack(&entry) - } - - l.Remove(&entry) - verifyEquality(t, entries[:n], &l) - } - } -} - -func TestReset(t *testing.T) { - var l List - - // Resetting list of one element. - l.PushBack(&testEntry{}) - if l.Front() == nil { - t.Error("List is empty") - } - - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } - - // Resetting list of 10 elements. - for i := 0; i < 10; i++ { - l.PushBack(&testEntry{}) - } - - if l.Front() == nil { - t.Error("List is empty") - } - - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } - - // Resetting empty list. - l.Reset() - if l.Front() != nil { - t.Error("List is not empty") - } -} - -func BenchmarkIterateForward(b *testing.B) { - var l List - for i := 0; i < 1000000; i++ { - l.PushBack(&testEntry{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Front(); e != nil; e = e.Next() { - tmp += e.(*testEntry).value - } - } -} - -func BenchmarkIterateBackward(b *testing.B) { - var l List - for i := 0; i < 1000000; i++ { - l.PushBack(&testEntry{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Back(); e != nil; e = e.Prev() { - tmp += e.(*testEntry).value - } - } -} - -func BenchmarkDirectIterateForward(b *testing.B) { - var l directList - for i := 0; i < 1000000; i++ { - l.PushBack(&direct{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Front(); e != nil; e = e.Next() { - tmp += e.value - } - } -} - -func BenchmarkDirectIterateBackward(b *testing.B) { - var l directList - for i := 0; i < 1000000; i++ { - l.PushBack(&direct{value: i}) - } - - for i := b.N; i > 0; i-- { - tmp := 0 - for e := l.Back(); e != nil; e = e.Prev() { - tmp += e.value - } - } -} diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD deleted file mode 100644 index c8e923a74..000000000 --- a/pkg/linewriter/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "linewriter", - srcs = ["linewriter.go"], - importpath = "gvisor.dev/gvisor/pkg/linewriter", - visibility = ["//visibility:public"], -) - -go_test( - name = "linewriter_test", - srcs = ["linewriter_test.go"], - embed = [":linewriter"], -) diff --git a/pkg/linewriter/linewriter_state_autogen.go b/pkg/linewriter/linewriter_state_autogen.go new file mode 100755 index 000000000..194088d76 --- /dev/null +++ b/pkg/linewriter/linewriter_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package linewriter + diff --git a/pkg/linewriter/linewriter_test.go b/pkg/linewriter/linewriter_test.go deleted file mode 100644 index 96dc7e6e0..000000000 --- a/pkg/linewriter/linewriter_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package linewriter - -import ( - "bytes" - "testing" -) - -func TestWriter(t *testing.T) { - testCases := []struct { - input []string - want []string - }{ - { - input: []string{"1\n", "2\n"}, - want: []string{"1", "2"}, - }, - { - input: []string{"1\n", "\n", "2\n"}, - want: []string{"1", "", "2"}, - }, - { - input: []string{"1\n2\n", "3\n"}, - want: []string{"1", "2", "3"}, - }, - { - input: []string{"1", "2\n"}, - want: []string{"12"}, - }, - { - // Data with no newline yet is omitted. - input: []string{"1\n", "2\n", "3"}, - want: []string{"1", "2"}, - }, - } - - for _, c := range testCases { - var lines [][]byte - - w := NewWriter(func(p []byte) { - // We must not retain p, so we must make a copy. - b := make([]byte, len(p)) - copy(b, p) - - lines = append(lines, b) - }) - - for _, in := range c.input { - n, err := w.Write([]byte(in)) - if err != nil { - t.Errorf("Write(%q) err got %v want nil (case %+v)", in, err, c) - } - if n != len(in) { - t.Errorf("Write(%q) b got %d want %d (case %+v)", in, n, len(in), c) - } - } - - if len(lines) != len(c.want) { - t.Errorf("len(lines) got %d want %d (case %+v)", len(lines), len(c.want), c) - } - - for i := range lines { - if !bytes.Equal(lines[i], []byte(c.want[i])) { - t.Errorf("item %d got %q want %q (case %+v)", i, lines[i], c.want[i], c) - } - } - } -} diff --git a/pkg/log/BUILD b/pkg/log/BUILD deleted file mode 100644 index 12615240c..000000000 --- a/pkg/log/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "log", - srcs = [ - "glog.go", - "glog_unsafe.go", - "json.go", - "json_k8s.go", - "log.go", - ], - importpath = "gvisor.dev/gvisor/pkg/log", - visibility = [ - "//visibility:public", - ], - deps = ["//pkg/linewriter"], -) - -go_test( - name = "log_test", - size = "small", - srcs = [ - "json_test.go", - "log_test.go", - ], - embed = [":log"], -) diff --git a/pkg/log/json_test.go b/pkg/log/json_test.go deleted file mode 100644 index f25224fe1..000000000 --- a/pkg/log/json_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package log - -import ( - "encoding/json" - "testing" -) - -// Tests that Level can marshal/unmarshal properly. -func TestLevelMarshal(t *testing.T) { - lvs := []Level{Warning, Info, Debug} - for _, lv := range lvs { - bs, err := lv.MarshalJSON() - if err != nil { - t.Errorf("error marshaling %v: %v", lv, err) - } - var lv2 Level - if err := lv2.UnmarshalJSON(bs); err != nil { - t.Errorf("error unmarshaling %v: %v", bs, err) - } - if lv != lv2 { - t.Errorf("marshal/unmarshal level got %v wanted %v", lv2, lv) - } - } -} - -// Test that integers can be properly unmarshaled. -func TestUnmarshalFromInt(t *testing.T) { - tcs := []struct { - i int - want Level - }{ - {0, Warning}, - {1, Info}, - {2, Debug}, - } - - for _, tc := range tcs { - j, err := json.Marshal(tc.i) - if err != nil { - t.Errorf("error marshaling %v: %v", tc.i, err) - } - var lv Level - if err := lv.UnmarshalJSON(j); err != nil { - t.Errorf("error unmarshaling %v: %v", j, err) - } - if lv != tc.want { - t.Errorf("marshal/unmarshal %v got %v want %v", tc.i, lv, tc.want) - } - } -} diff --git a/pkg/log/log_state_autogen.go b/pkg/log/log_state_autogen.go new file mode 100755 index 000000000..010b760a5 --- /dev/null +++ b/pkg/log/log_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package log + diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go deleted file mode 100644 index 0634e7c1f..000000000 --- a/pkg/log/log_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package log - -import ( - "fmt" - "testing" -) - -type testWriter struct { - lines []string - fail bool -} - -func (w *testWriter) Write(bytes []byte) (int, error) { - if w.fail { - return 0, fmt.Errorf("simulated failure") - } - w.lines = append(w.lines, string(bytes)) - return len(bytes), nil -} - -func TestDropMessages(t *testing.T) { - tw := &testWriter{} - w := Writer{Next: tw} - if _, err := w.Write([]byte("line 1\n")); err != nil { - t.Fatalf("Write failed, err: %v", err) - } - - tw.fail = true - if _, err := w.Write([]byte("error\n")); err == nil { - t.Fatalf("Write should have failed") - } - if _, err := w.Write([]byte("error\n")); err == nil { - t.Fatalf("Write should have failed") - } - - fmt.Printf("writer: %+v\n", w) - - tw.fail = false - if _, err := w.Write([]byte("line 2\n")); err != nil { - t.Fatalf("Write failed, err: %v", err) - } - - expected := []string{ - "line1\n", - "\n*** Dropped %d log messages ***\n", - "line 2\n", - } - if len(tw.lines) != len(expected) { - t.Fatalf("Writer should have logged %d lines, got: %v, expected: %v", len(expected), tw.lines, expected) - } - for i, l := range tw.lines { - if l == expected[i] { - t.Fatalf("line %d doesn't match, got: %v, expected: %v", i, l, expected[i]) - } - } -} diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD deleted file mode 100644 index 7b50e2b28..000000000 --- a/pkg/memutil/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "memutil", - srcs = ["memutil_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/memutil", - visibility = ["//visibility:public"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/memutil/memutil_state_autogen.go b/pkg/memutil/memutil_state_autogen.go new file mode 100755 index 000000000..52f337963 --- /dev/null +++ b/pkg/memutil/memutil_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package memutil + diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD deleted file mode 100644 index 3b8a691f4..000000000 --- a/pkg/metric/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "metric", - srcs = ["metric.go"], - importpath = "gvisor.dev/gvisor/pkg/metric", - visibility = ["//:sandbox"], - deps = [ - ":metric_go_proto", - "//pkg/eventchannel", - "//pkg/log", - ], -) - -proto_library( - name = "metric_proto", - srcs = ["metric.proto"], - visibility = ["//:sandbox"], -) - -go_proto_library( - name = "metric_go_proto", - importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", - proto = ":metric_proto", - visibility = ["//:sandbox"], -) - -go_test( - name = "metric_test", - srcs = ["metric_test.go"], - embed = [":metric"], - deps = [ - ":metric_go_proto", - "//pkg/eventchannel", - "@com_github_golang_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto deleted file mode 100644 index a2c2bd1ba..000000000 --- a/pkg/metric/metric.proto +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// MetricMetadata contains all of the metadata describing a single metric. -message MetricMetadata { - // name is the unique name of the metric, usually in a "directory" format - // (e.g., /foo/count). - string name = 1; - - // description is a human-readable description of the metric. - string description = 2; - - // cumulative indicates that this metric is never decremented. - bool cumulative = 3; - - // sync indicates that values from the final metric event should be - // synchronized to the backing monitoring system at exit. - // - // If sync is false, values are only sent to the monitoring system - // periodically. There is no guarantee that values will ever be received by - // the monitoring system. - bool sync = 4; - - enum Type { UINT64 = 0; } - - // type is the type of the metric value. - Type type = 5; -} - -// MetricRegistration contains the metadata for all metrics that will be in -// future MetricUpdates. -message MetricRegistration { - repeated MetricMetadata metrics = 1; -} - -// MetricValue the value of a metric at a single point in time. -message MetricValue { - // name is the unique name of the metric, as in MetricMetadata. - string name = 1; - - // value is the value of the metric at a single point in time. The field set - // depends on the type of the metric. - oneof value { - uint64 uint64_value = 2; - } -} - -// MetricUpdate contains new values for multiple distinct metrics. -// -// Metrics whose values have not changed are not included. -message MetricUpdate { - repeated MetricValue metrics = 1; -} diff --git a/pkg/metric/metric_go_proto/metric.pb.go b/pkg/metric/metric_go_proto/metric.pb.go new file mode 100755 index 000000000..553236535 --- /dev/null +++ b/pkg/metric/metric_go_proto/metric.pb.go @@ -0,0 +1,297 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/metric/metric.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type MetricMetadata_Type int32 + +const ( + MetricMetadata_UINT64 MetricMetadata_Type = 0 +) + +var MetricMetadata_Type_name = map[int32]string{ + 0: "UINT64", +} + +var MetricMetadata_Type_value = map[string]int32{ + "UINT64": 0, +} + +func (x MetricMetadata_Type) String() string { + return proto.EnumName(MetricMetadata_Type_name, int32(x)) +} + +func (MetricMetadata_Type) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{0, 0} +} + +type MetricMetadata struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"` + Cumulative bool `protobuf:"varint,3,opt,name=cumulative,proto3" json:"cumulative,omitempty"` + Sync bool `protobuf:"varint,4,opt,name=sync,proto3" json:"sync,omitempty"` + Type MetricMetadata_Type `protobuf:"varint,5,opt,name=type,proto3,enum=gvisor.MetricMetadata_Type" json:"type,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricMetadata) Reset() { *m = MetricMetadata{} } +func (m *MetricMetadata) String() string { return proto.CompactTextString(m) } +func (*MetricMetadata) ProtoMessage() {} +func (*MetricMetadata) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{0} +} + +func (m *MetricMetadata) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricMetadata.Unmarshal(m, b) +} +func (m *MetricMetadata) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricMetadata.Marshal(b, m, deterministic) +} +func (m *MetricMetadata) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricMetadata.Merge(m, src) +} +func (m *MetricMetadata) XXX_Size() int { + return xxx_messageInfo_MetricMetadata.Size(m) +} +func (m *MetricMetadata) XXX_DiscardUnknown() { + xxx_messageInfo_MetricMetadata.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricMetadata proto.InternalMessageInfo + +func (m *MetricMetadata) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *MetricMetadata) GetDescription() string { + if m != nil { + return m.Description + } + return "" +} + +func (m *MetricMetadata) GetCumulative() bool { + if m != nil { + return m.Cumulative + } + return false +} + +func (m *MetricMetadata) GetSync() bool { + if m != nil { + return m.Sync + } + return false +} + +func (m *MetricMetadata) GetType() MetricMetadata_Type { + if m != nil { + return m.Type + } + return MetricMetadata_UINT64 +} + +type MetricRegistration struct { + Metrics []*MetricMetadata `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricRegistration) Reset() { *m = MetricRegistration{} } +func (m *MetricRegistration) String() string { return proto.CompactTextString(m) } +func (*MetricRegistration) ProtoMessage() {} +func (*MetricRegistration) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{1} +} + +func (m *MetricRegistration) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricRegistration.Unmarshal(m, b) +} +func (m *MetricRegistration) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricRegistration.Marshal(b, m, deterministic) +} +func (m *MetricRegistration) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricRegistration.Merge(m, src) +} +func (m *MetricRegistration) XXX_Size() int { + return xxx_messageInfo_MetricRegistration.Size(m) +} +func (m *MetricRegistration) XXX_DiscardUnknown() { + xxx_messageInfo_MetricRegistration.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricRegistration proto.InternalMessageInfo + +func (m *MetricRegistration) GetMetrics() []*MetricMetadata { + if m != nil { + return m.Metrics + } + return nil +} + +type MetricValue struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // Types that are valid to be assigned to Value: + // *MetricValue_Uint64Value + Value isMetricValue_Value `protobuf_oneof:"value"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricValue) Reset() { *m = MetricValue{} } +func (m *MetricValue) String() string { return proto.CompactTextString(m) } +func (*MetricValue) ProtoMessage() {} +func (*MetricValue) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{2} +} + +func (m *MetricValue) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricValue.Unmarshal(m, b) +} +func (m *MetricValue) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricValue.Marshal(b, m, deterministic) +} +func (m *MetricValue) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricValue.Merge(m, src) +} +func (m *MetricValue) XXX_Size() int { + return xxx_messageInfo_MetricValue.Size(m) +} +func (m *MetricValue) XXX_DiscardUnknown() { + xxx_messageInfo_MetricValue.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricValue proto.InternalMessageInfo + +func (m *MetricValue) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type isMetricValue_Value interface { + isMetricValue_Value() +} + +type MetricValue_Uint64Value struct { + Uint64Value uint64 `protobuf:"varint,2,opt,name=uint64_value,json=uint64Value,proto3,oneof"` +} + +func (*MetricValue_Uint64Value) isMetricValue_Value() {} + +func (m *MetricValue) GetValue() isMetricValue_Value { + if m != nil { + return m.Value + } + return nil +} + +func (m *MetricValue) GetUint64Value() uint64 { + if x, ok := m.GetValue().(*MetricValue_Uint64Value); ok { + return x.Uint64Value + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*MetricValue) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*MetricValue_Uint64Value)(nil), + } +} + +type MetricUpdate struct { + Metrics []*MetricValue `protobuf:"bytes,1,rep,name=metrics,proto3" json:"metrics,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *MetricUpdate) Reset() { *m = MetricUpdate{} } +func (m *MetricUpdate) String() string { return proto.CompactTextString(m) } +func (*MetricUpdate) ProtoMessage() {} +func (*MetricUpdate) Descriptor() ([]byte, []int) { + return fileDescriptor_87b8778a4ff2ab5c, []int{3} +} + +func (m *MetricUpdate) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_MetricUpdate.Unmarshal(m, b) +} +func (m *MetricUpdate) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_MetricUpdate.Marshal(b, m, deterministic) +} +func (m *MetricUpdate) XXX_Merge(src proto.Message) { + xxx_messageInfo_MetricUpdate.Merge(m, src) +} +func (m *MetricUpdate) XXX_Size() int { + return xxx_messageInfo_MetricUpdate.Size(m) +} +func (m *MetricUpdate) XXX_DiscardUnknown() { + xxx_messageInfo_MetricUpdate.DiscardUnknown(m) +} + +var xxx_messageInfo_MetricUpdate proto.InternalMessageInfo + +func (m *MetricUpdate) GetMetrics() []*MetricValue { + if m != nil { + return m.Metrics + } + return nil +} + +func init() { + proto.RegisterEnum("gvisor.MetricMetadata_Type", MetricMetadata_Type_name, MetricMetadata_Type_value) + proto.RegisterType((*MetricMetadata)(nil), "gvisor.MetricMetadata") + proto.RegisterType((*MetricRegistration)(nil), "gvisor.MetricRegistration") + proto.RegisterType((*MetricValue)(nil), "gvisor.MetricValue") + proto.RegisterType((*MetricUpdate)(nil), "gvisor.MetricUpdate") +} + +func init() { proto.RegisterFile("pkg/metric/metric.proto", fileDescriptor_87b8778a4ff2ab5c) } + +var fileDescriptor_87b8778a4ff2ab5c = []byte{ + // 288 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x91, 0xc1, 0x4b, 0xc3, 0x30, + 0x14, 0xc6, 0x17, 0xd7, 0x75, 0xfa, 0x3a, 0x86, 0x44, 0xd0, 0x80, 0x20, 0xa5, 0x5e, 0x7a, 0xb1, + 0x93, 0x39, 0x76, 0xf3, 0xe2, 0x41, 0xf4, 0x30, 0x85, 0xb0, 0x79, 0x95, 0xd8, 0x86, 0x12, 0x5c, + 0xdb, 0xd0, 0xa4, 0x85, 0xfe, 0x75, 0xfe, 0x6b, 0xd2, 0x17, 0x95, 0x4d, 0x76, 0xca, 0xcb, 0xfb, + 0xde, 0xf7, 0xf1, 0xcb, 0x0b, 0x5c, 0xe8, 0xcf, 0x7c, 0x56, 0x48, 0x5b, 0xab, 0xf4, 0xe7, 0x48, + 0x74, 0x5d, 0xd9, 0x8a, 0xfa, 0x79, 0xab, 0x4c, 0x55, 0x47, 0x5f, 0x04, 0xa6, 0x2b, 0x14, 0x56, + 0xd2, 0x8a, 0x4c, 0x58, 0x41, 0x29, 0x78, 0xa5, 0x28, 0x24, 0x23, 0x21, 0x89, 0x4f, 0x38, 0xd6, + 0x34, 0x84, 0x20, 0x93, 0x26, 0xad, 0x95, 0xb6, 0xaa, 0x2a, 0xd9, 0x11, 0x4a, 0xbb, 0x2d, 0x7a, + 0x05, 0x90, 0x36, 0x45, 0xb3, 0x15, 0x56, 0xb5, 0x92, 0x0d, 0x43, 0x12, 0x1f, 0xf3, 0x9d, 0x4e, + 0x9f, 0x6a, 0xba, 0x32, 0x65, 0x1e, 0x2a, 0x58, 0xd3, 0x19, 0x78, 0xb6, 0xd3, 0x92, 0x8d, 0x42, + 0x12, 0x4f, 0xe7, 0x97, 0x89, 0x63, 0x4a, 0xf6, 0x79, 0x92, 0x75, 0xa7, 0x25, 0xc7, 0xc1, 0x88, + 0x82, 0xd7, 0xdf, 0x28, 0x80, 0xbf, 0x79, 0x7e, 0x59, 0x2f, 0x17, 0xa7, 0x83, 0xe8, 0x11, 0xa8, + 0x33, 0x70, 0x99, 0x2b, 0x63, 0x6b, 0x81, 0x38, 0xb7, 0x30, 0x76, 0xef, 0x35, 0x8c, 0x84, 0xc3, + 0x38, 0x98, 0x9f, 0x1f, 0x4e, 0xe7, 0xbf, 0x63, 0xd1, 0x2b, 0x04, 0x4e, 0x7a, 0x13, 0xdb, 0x46, + 0x1e, 0xdc, 0xc2, 0x35, 0x4c, 0x1a, 0x55, 0xda, 0xe5, 0xe2, 0xbd, 0xed, 0x67, 0x70, 0x0d, 0xde, + 0xd3, 0x80, 0x07, 0xae, 0x8b, 0xc6, 0x87, 0x31, 0x8c, 0x50, 0x8d, 0xee, 0x61, 0xe2, 0x02, 0x37, + 0x3a, 0x13, 0x56, 0xd2, 0x9b, 0xff, 0x48, 0x67, 0xfb, 0x48, 0x68, 0xff, 0xe3, 0xf9, 0xf0, 0xf1, + 0xa3, 0xee, 0xbe, 0x03, 0x00, 0x00, 0xff, 0xff, 0xcb, 0x7f, 0xcb, 0x46, 0xc3, 0x01, 0x00, 0x00, +} diff --git a/pkg/metric/metric_state_autogen.go b/pkg/metric/metric_state_autogen.go new file mode 100755 index 000000000..985c28832 --- /dev/null +++ b/pkg/metric/metric_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package metric + diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go deleted file mode 100644 index 34969385a..000000000 --- a/pkg/metric/metric_test.go +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metric - -import ( - "testing" - - "github.com/golang/protobuf/proto" - "gvisor.dev/gvisor/pkg/eventchannel" - pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" -) - -// sliceEmitter implements eventchannel.Emitter by appending all messages to a -// slice. -type sliceEmitter []proto.Message - -// Emit implements eventchannel.Emitter.Emit. -func (s *sliceEmitter) Emit(msg proto.Message) (bool, error) { - *s = append(*s, msg) - return false, nil -} - -// Emit implements eventchannel.Emitter.Close. -func (s *sliceEmitter) Close() error { - return nil -} - -// Reset clears all events in s. -func (s *sliceEmitter) Reset() { - *s = nil -} - -// emitter is the eventchannel.Emitter used for all tests. Package eventchannel -// doesn't allow removing Emitters, so we must use one global emitter for all -// test cases. -var emitter sliceEmitter - -func init() { - eventchannel.AddEmitter(&emitter) -} - -// reset clears all global state in the metric package. -func reset() { - initialized = false - allMetrics = makeMetricSet() - emitter.Reset() -} - -const ( - fooDescription = "Foo!" - barDescription = "Bar Baz" -) - -func TestInitialize(t *testing.T) { - defer reset() - - _, err := NewUint64Metric("/foo", false, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - Initialize() - - if len(emitter) != 1 { - t.Fatalf("Initialize emitted %d events want 1", len(emitter)) - } - - mr, ok := emitter[0].(*pb.MetricRegistration) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0]) - } - - if len(mr.Metrics) != 2 { - t.Errorf("MetricRegistration got %d metrics want 2", len(mr.Metrics)) - } - - foundFoo := false - foundBar := false - for _, m := range mr.Metrics { - if m.Type != pb.MetricMetadata_UINT64 { - t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_UINT64) - } - if !m.Cumulative { - t.Errorf("Metadata %+v Cumulative got false want true", m) - } - - switch m.Name { - case "/foo": - foundFoo = true - if m.Description != fooDescription { - t.Errorf("/foo %+v Description got %q want %q", m, m.Description, fooDescription) - } - if m.Sync { - t.Errorf("/foo %+v Sync got true want false", m) - } - case "/bar": - foundBar = true - if m.Description != barDescription { - t.Errorf("/bar %+v Description got %q want %q", m, m.Description, barDescription) - } - if !m.Sync { - t.Errorf("/bar %+v Sync got true want false", m) - } - } - } - - if !foundFoo { - t.Errorf("/foo not found: %+v", emitter) - } - if !foundBar { - t.Errorf("/bar not found: %+v", emitter) - } -} - -func TestDisable(t *testing.T) { - defer reset() - - _, err := NewUint64Metric("/foo", false, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - Disable() - - if len(emitter) != 1 { - t.Fatalf("Initialize emitted %d events want 1", len(emitter)) - } - - mr, ok := emitter[0].(*pb.MetricRegistration) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricRegistration", emitter[0], emitter[0]) - } - - if len(mr.Metrics) != 0 { - t.Errorf("MetricRegistration got %d metrics want 0", len(mr.Metrics)) - } -} - -func TestEmitMetricUpdate(t *testing.T) { - defer reset() - - foo, err := NewUint64Metric("/foo", false, fooDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - _, err = NewUint64Metric("/bar", true, barDescription) - if err != nil { - t.Fatalf("NewUint64Metric got err %v want nil", err) - } - - Initialize() - - // Don't care about the registration metrics. - emitter.Reset() - EmitMetricUpdate() - - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) - } - - update, ok := emitter[0].(*pb.MetricUpdate) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) - } - - if len(update.Metrics) != 2 { - t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics)) - } - - // Both are included for their initial values. - foundFoo := false - foundBar := false - for _, m := range update.Metrics { - switch m.Name { - case "/foo": - foundFoo = true - case "/bar": - foundBar = true - } - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - continue - } - if uv.Uint64Value != 0 { - t.Errorf("%v: Value got %v want 0", m, uv.Uint64Value) - } - } - - if !foundFoo { - t.Errorf("/foo not found: %+v", emitter) - } - if !foundBar { - t.Errorf("/bar not found: %+v", emitter) - } - - // Increment foo. Only it is included in the next update. - foo.Increment() - - emitter.Reset() - EmitMetricUpdate() - - if len(emitter) != 1 { - t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) - } - - update, ok = emitter[0].(*pb.MetricUpdate) - if !ok { - t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) - } - - if len(update.Metrics) != 1 { - t.Errorf("MetricUpdate got %d metrics want 1", len(update.Metrics)) - } - - m := update.Metrics[0] - - if m.Name != "/foo" { - t.Errorf("Metric %+v name got %q want '/foo'", m, m.Name) - } - - uv, ok := m.Value.(*pb.MetricValue_Uint64Value) - if !ok { - t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) - } - if uv.Uint64Value != 1 { - t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) - } -} diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD deleted file mode 100644 index c6737bf97..000000000 --- a/pkg/p9/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -go_library( - name = "p9", - srcs = [ - "buffer.go", - "client.go", - "client_file.go", - "file.go", - "handlers.go", - "messages.go", - "p9.go", - "path_tree.go", - "pool.go", - "server.go", - "transport.go", - "version.go", - ], - importpath = "gvisor.dev/gvisor/pkg/p9", - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/unet", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "p9_test", - size = "small", - srcs = [ - "buffer_test.go", - "client_test.go", - "messages_test.go", - "p9_test.go", - "pool_test.go", - "transport_test.go", - "version_test.go", - ], - embed = [":p9"], - deps = [ - "//pkg/fd", - "//pkg/unet", - ], -) diff --git a/pkg/p9/buffer_test.go b/pkg/p9/buffer_test.go deleted file mode 100644 index a9c75f86b..000000000 --- a/pkg/p9/buffer_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" -) - -func TestBufferOverrun(t *testing.T) { - buf := &buffer{ - // This header indicates that a large string should follow, but - // it is only two bytes. Reading a string should cause an - // overrun. - data: []byte{0x0, 0x16}, - } - if s := buf.ReadString(); s != "" { - t.Errorf("overrun read got %s, want empty", s) - } -} diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go deleted file mode 100644 index 87b2dd61e..000000000 --- a/pkg/p9/client_test.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/unet" -) - -// TestVersion tests the version negotiation. -func TestVersion(t *testing.T) { - // First, create a new server and connection. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer clientSocket.Close() - - // Create a new server and client. - s := NewServer(nil) - go s.Handle(serverSocket) - - // NewClient does a Tversion exchange, so this is our test for success. - c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString()) - if err != nil { - t.Fatalf("got %v, expected nil", err) - } - - // Check a bogus version string. - if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { - t.Errorf("got %v expected %v", err, syscall.EINVAL) - } - - // Check a bogus version number. - if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { - t.Errorf("got %v expected %v", err, syscall.EINVAL) - } - - // Check a too high version number. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN { - t.Errorf("got %v expected %v", err, syscall.EAGAIN) - } - - // Check an invalid MSize. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion), MSize: 0}, &Rversion{}); err != syscall.EINVAL { - t.Errorf("got %v expected %v", err, syscall.EINVAL) - } -} diff --git a/pkg/p9/local_server/BUILD b/pkg/p9/local_server/BUILD deleted file mode 100644 index aa6db186c..000000000 --- a/pkg/p9/local_server/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "local_server", - srcs = ["local_server.go"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/p9", - "//pkg/unet", - ], -) diff --git a/pkg/p9/local_server/local_server.go b/pkg/p9/local_server/local_server.go deleted file mode 100644 index 905188aaa..000000000 --- a/pkg/p9/local_server/local_server.go +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary local_server provides a local 9P2000.L server for the p9 package. -// -// To use, first start the server: -// local_server /tmp/my_bind_addr -// -// Then, connect using the Linux 9P filesystem: -// mount -t 9p -o trans=unix /tmp/my_bind_addr /mnt -// -// This package also serves as an examplar. -package main - -import ( - "os" - "path" - "syscall" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/unet" -) - -// local wraps a local file. -type local struct { - p9.DefaultWalkGetAttr - - path string - file *os.File -} - -// info constructs a QID for this file. -func (l *local) info() (p9.QID, os.FileInfo, error) { - var ( - qid p9.QID - fi os.FileInfo - err error - ) - - // Stat the file. - if l.file != nil { - fi, err = l.file.Stat() - } else { - fi, err = os.Lstat(l.path) - } - if err != nil { - log.Warningf("error stating %#v: %v", l, err) - return qid, nil, err - } - - // Construct the QID type. - qid.Type = p9.ModeFromOS(fi.Mode()).QIDType() - - // Save the path from the Ino. - qid.Path = fi.Sys().(*syscall.Stat_t).Ino - return qid, fi, nil -} - -// Attach implements p9.Attacher.Attach. -func (l *local) Attach() (p9.File, error) { - return &local{path: "/"}, nil -} - -// Walk implements p9.File.Walk. -func (l *local) Walk(names []string) ([]p9.QID, p9.File, error) { - var qids []p9.QID - last := &local{path: l.path} - for _, name := range names { - c := &local{path: path.Join(last.path, name)} - qid, _, err := c.info() - if err != nil { - return nil, nil, err - } - qids = append(qids, qid) - last = c - } - return qids, last, nil -} - -// StatFS implements p9.File.StatFS. -// -// Not implemented. -func (l *local) StatFS() (p9.FSStat, error) { - return p9.FSStat{}, syscall.ENOSYS -} - -// FSync implements p9.File.FSync. -func (l *local) FSync() error { - return l.file.Sync() -} - -// GetAttr implements p9.File.GetAttr. -// -// Not fully implemented. -func (l *local) GetAttr(req p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { - qid, fi, err := l.info() - if err != nil { - return qid, p9.AttrMask{}, p9.Attr{}, err - } - - stat := fi.Sys().(*syscall.Stat_t) - attr := p9.Attr{ - Mode: p9.FileMode(stat.Mode), - UID: p9.UID(stat.Uid), - GID: p9.GID(stat.Gid), - NLink: stat.Nlink, - RDev: stat.Rdev, - Size: uint64(stat.Size), - BlockSize: uint64(stat.Blksize), - Blocks: uint64(stat.Blocks), - ATimeSeconds: uint64(stat.Atim.Sec), - ATimeNanoSeconds: uint64(stat.Atim.Nsec), - MTimeSeconds: uint64(stat.Mtim.Sec), - MTimeNanoSeconds: uint64(stat.Mtim.Nsec), - CTimeSeconds: uint64(stat.Ctim.Sec), - CTimeNanoSeconds: uint64(stat.Ctim.Nsec), - } - valid := p9.AttrMask{ - Mode: true, - UID: true, - GID: true, - NLink: true, - RDev: true, - Size: true, - Blocks: true, - ATime: true, - MTime: true, - CTime: true, - } - - return qid, valid, attr, nil -} - -// SetAttr implements p9.File.SetAttr. -// -// Not implemented. -func (l *local) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { - return syscall.ENOSYS -} - -// Remove implements p9.File.Remove. -// -// Not implemented. -func (l *local) Remove() error { - return syscall.ENOSYS -} - -// Rename implements p9.File.Rename. -// -// Not implemented. -func (l *local) Rename(directory p9.File, name string) error { - return syscall.ENOSYS -} - -// Close implements p9.File.Close. -func (l *local) Close() error { - if l.file != nil { - return l.file.Close() - } - return nil -} - -// Open implements p9.File.Open. -func (l *local) Open(mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { - qid, _, err := l.info() - if err != nil { - return nil, qid, 0, err - } - - // Do the actual open. - f, err := os.OpenFile(l.path, int(mode), 0) - if err != nil { - return nil, qid, 0, err - } - l.file = f - - // Note: we don't send the local file for this server. - return nil, qid, 4096, nil -} - -// Read implements p9.File.Read. -func (l *local) ReadAt(p []byte, offset uint64) (int, error) { - return l.file.ReadAt(p, int64(offset)) -} - -// Write implements p9.File.Write. -func (l *local) WriteAt(p []byte, offset uint64) (int, error) { - return l.file.WriteAt(p, int64(offset)) -} - -// Create implements p9.File.Create. -func (l *local) Create(name string, mode p9.OpenFlags, permissions p9.FileMode, _ p9.UID, _ p9.GID) (*fd.FD, p9.File, p9.QID, uint32, error) { - f, err := os.OpenFile(l.path, int(mode)|syscall.O_CREAT|syscall.O_EXCL, os.FileMode(permissions)) - if err != nil { - return nil, nil, p9.QID{}, 0, err - } - - l2 := &local{path: path.Join(l.path, name), file: f} - qid, _, err := l2.info() - if err != nil { - l2.Close() - return nil, nil, p9.QID{}, 0, err - } - - return nil, l2, qid, 4096, nil -} - -// Mkdir implements p9.File.Mkdir. -// -// Not properly implemented. -func (l *local) Mkdir(name string, permissions p9.FileMode, _ p9.UID, _ p9.GID) (p9.QID, error) { - if err := os.Mkdir(path.Join(l.path, name), os.FileMode(permissions)); err != nil { - return p9.QID{}, err - } - - // Blank QID. - return p9.QID{}, nil -} - -// Symlink implements p9.File.Symlink. -// -// Not properly implemented. -func (l *local) Symlink(oldname string, newname string, _ p9.UID, _ p9.GID) (p9.QID, error) { - if err := os.Symlink(oldname, path.Join(l.path, newname)); err != nil { - return p9.QID{}, err - } - - // Blank QID. - return p9.QID{}, nil -} - -// Link implements p9.File.Link. -// -// Not properly implemented. -func (l *local) Link(target p9.File, newname string) error { - return os.Link(target.(*local).path, path.Join(l.path, newname)) -} - -// Mknod implements p9.File.Mknod. -// -// Not implemented. -func (l *local) Mknod(name string, mode p9.FileMode, major uint32, minor uint32, _ p9.UID, _ p9.GID) (p9.QID, error) { - return p9.QID{}, syscall.ENOSYS -} - -// RenameAt implements p9.File.RenameAt. -// -// Not implemented. -func (l *local) RenameAt(oldname string, newdir p9.File, newname string) error { - return syscall.ENOSYS -} - -// UnlinkAt implements p9.File.UnlinkAt. -// -// Not implemented. -func (l *local) UnlinkAt(name string, flags uint32) error { - return syscall.ENOSYS -} - -// Readdir implements p9.File.Readdir. -func (l *local) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) { - // We only do *all* dirents in single shot. - const maxDirentBuffer = 1024 * 1024 - buf := make([]byte, maxDirentBuffer) - n, err := syscall.ReadDirent(int(l.file.Fd()), buf) - if err != nil { - // Return zero entries. - return nil, nil - } - - // Parse the entries; note that we read up to offset+count here. - _, newCount, newNames := syscall.ParseDirent(buf[:n], int(offset)+int(count), nil) - var dirents []p9.Dirent - for i := int(offset); i >= 0 && i < newCount; i++ { - entry := local{path: path.Join(l.path, newNames[i])} - qid, _, err := entry.info() - if err != nil { - continue - } - dirents = append(dirents, p9.Dirent{ - QID: qid, - Type: qid.Type, - Name: newNames[i], - Offset: uint64(i + 1), - }) - } - - return dirents, nil -} - -// Readlink implements p9.File.Readlink. -// -// Not properly implemented. -func (l *local) Readlink() (string, error) { - return os.Readlink(l.path) -} - -// Flush implements p9.File.Flush. -func (l *local) Flush() error { - return nil -} - -// Connect implements p9.File.Connect. -func (l *local) Connect(p9.ConnectFlags) (*fd.FD, error) { - return nil, syscall.ECONNREFUSED -} - -// Renamed implements p9.File.Renamed. -func (l *local) Renamed(parent p9.File, newName string) { - l.path = path.Join(parent.(*local).path, newName) -} - -// Allocate implements p9.File.Allocate. -func (l *local) Allocate(mode p9.AllocateMode, offset, length uint64) error { - return syscall.Fallocate(int(l.file.Fd()), mode.ToLinux(), int64(offset), int64(length)) -} - -func main() { - log.SetLevel(log.Debug) - - if len(os.Args) != 2 { - log.Warningf("usage: %s <bind-addr>", os.Args[0]) - os.Exit(1) - } - - // Bind and listen on the socket. - serverSocket, err := unet.BindAndListen(os.Args[1], false) - if err != nil { - log.Warningf("err binding: %v", err) - os.Exit(1) - } - - // Run the server. - s := p9.NewServer(&local{}) - s.Serve(serverSocket) -} - -var ( - _ p9.File = &local{} -) diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go deleted file mode 100644 index 6ba6a1654..000000000 --- a/pkg/p9/messages_test.go +++ /dev/null @@ -1,468 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "fmt" - "reflect" - "testing" -) - -func TestEncodeDecode(t *testing.T) { - objs := []encoder{ - &QID{ - Type: 1, - Version: 2, - Path: 3, - }, - &FSStat{ - Type: 1, - BlockSize: 2, - Blocks: 3, - BlocksFree: 4, - BlocksAvailable: 5, - Files: 6, - FilesFree: 7, - FSID: 8, - NameLength: 9, - }, - &AttrMask{ - Mode: true, - NLink: true, - UID: true, - GID: true, - RDev: true, - ATime: true, - MTime: true, - CTime: true, - INo: true, - Size: true, - Blocks: true, - BTime: true, - Gen: true, - DataVersion: true, - }, - &Attr{ - Mode: Exec, - UID: 2, - GID: 3, - NLink: 4, - RDev: 5, - Size: 6, - BlockSize: 7, - Blocks: 8, - ATimeSeconds: 9, - ATimeNanoSeconds: 10, - MTimeSeconds: 11, - MTimeNanoSeconds: 12, - CTimeSeconds: 13, - CTimeNanoSeconds: 14, - BTimeSeconds: 15, - BTimeNanoSeconds: 16, - Gen: 17, - DataVersion: 18, - }, - &SetAttrMask{ - Permissions: true, - UID: true, - GID: true, - Size: true, - ATime: true, - MTime: true, - CTime: true, - ATimeNotSystemTime: true, - MTimeNotSystemTime: true, - }, - &SetAttr{ - Permissions: 1, - UID: 2, - GID: 3, - Size: 4, - ATimeSeconds: 5, - ATimeNanoSeconds: 6, - MTimeSeconds: 7, - MTimeNanoSeconds: 8, - }, - &Dirent{ - QID: QID{Type: 1}, - Offset: 2, - Type: 3, - Name: "a", - }, - &Rlerror{ - Error: 1, - }, - &Tstatfs{ - FID: 1, - }, - &Rstatfs{ - FSStat: FSStat{Type: 1}, - }, - &Tlopen{ - FID: 1, - Flags: WriteOnly, - }, - &Rlopen{ - QID: QID{Type: 1}, - IoUnit: 2, - }, - &Tlconnect{ - FID: 1, - }, - &Rlconnect{}, - &Tlcreate{ - FID: 1, - Name: "a", - OpenFlags: 2, - Permissions: 3, - GID: 4, - }, - &Rlcreate{ - Rlopen{QID: QID{Type: 1}}, - }, - &Tsymlink{ - Directory: 1, - Name: "a", - Target: "b", - GID: 2, - }, - &Rsymlink{ - QID: QID{Type: 1}, - }, - &Tmknod{ - Directory: 1, - Name: "a", - Mode: 2, - Major: 3, - Minor: 4, - GID: 5, - }, - &Rmknod{ - QID: QID{Type: 1}, - }, - &Trename{ - FID: 1, - Directory: 2, - Name: "a", - }, - &Rrename{}, - &Treadlink{ - FID: 1, - }, - &Rreadlink{ - Target: "a", - }, - &Tgetattr{ - FID: 1, - AttrMask: AttrMask{Mode: true}, - }, - &Rgetattr{ - Valid: AttrMask{Mode: true}, - QID: QID{Type: 1}, - Attr: Attr{Mode: Write}, - }, - &Tsetattr{ - FID: 1, - Valid: SetAttrMask{Permissions: true}, - SetAttr: SetAttr{Permissions: Write}, - }, - &Rsetattr{}, - &Txattrwalk{ - FID: 1, - NewFID: 2, - Name: "a", - }, - &Rxattrwalk{ - Size: 1, - }, - &Txattrcreate{ - FID: 1, - Name: "a", - AttrSize: 2, - Flags: 3, - }, - &Rxattrcreate{}, - &Treaddir{ - Directory: 1, - Offset: 2, - Count: 3, - }, - &Rreaddir{ - // Count must be sufficient to encode a dirent. - Count: 0x18, - Entries: []Dirent{{QID: QID{Type: 2}}}, - }, - &Tfsync{ - FID: 1, - }, - &Rfsync{}, - &Tlink{ - Directory: 1, - Target: 2, - Name: "a", - }, - &Rlink{}, - &Tmkdir{ - Directory: 1, - Name: "a", - Permissions: 2, - GID: 3, - }, - &Rmkdir{ - QID: QID{Type: 1}, - }, - &Trenameat{ - OldDirectory: 1, - OldName: "a", - NewDirectory: 2, - NewName: "b", - }, - &Rrenameat{}, - &Tunlinkat{ - Directory: 1, - Name: "a", - Flags: 2, - }, - &Runlinkat{}, - &Tversion{ - MSize: 1, - Version: "a", - }, - &Rversion{ - MSize: 1, - Version: "a", - }, - &Tauth{ - AuthenticationFID: 1, - UserName: "a", - AttachName: "b", - UID: 2, - }, - &Rauth{ - QID: QID{Type: 1}, - }, - &Tattach{ - FID: 1, - Auth: Tauth{AuthenticationFID: 2}, - }, - &Rattach{ - QID: QID{Type: 1}, - }, - &Tflush{ - OldTag: 1, - }, - &Rflush{}, - &Twalk{ - FID: 1, - NewFID: 2, - Names: []string{"a"}, - }, - &Rwalk{ - QIDs: []QID{{Type: 1}}, - }, - &Tread{ - FID: 1, - Offset: 2, - Count: 3, - }, - &Rread{ - Data: []byte{'a'}, - }, - &Twrite{ - FID: 1, - Offset: 2, - Data: []byte{'a'}, - }, - &Rwrite{ - Count: 1, - }, - &Tclunk{ - FID: 1, - }, - &Rclunk{}, - &Tremove{ - FID: 1, - }, - &Rremove{}, - &Tflushf{ - FID: 1, - }, - &Rflushf{}, - &Twalkgetattr{ - FID: 1, - NewFID: 2, - Names: []string{"a"}, - }, - &Rwalkgetattr{ - QIDs: []QID{{Type: 1}}, - Valid: AttrMask{Mode: true}, - Attr: Attr{Mode: Write}, - }, - &Tucreate{ - Tlcreate: Tlcreate{ - FID: 1, - Name: "a", - OpenFlags: 2, - Permissions: 3, - GID: 4, - }, - UID: 5, - }, - &Rucreate{ - Rlcreate{Rlopen{QID: QID{Type: 1}}}, - }, - &Tumkdir{ - Tmkdir: Tmkdir{ - Directory: 1, - Name: "a", - Permissions: 2, - GID: 3, - }, - UID: 4, - }, - &Rumkdir{ - Rmkdir{QID: QID{Type: 1}}, - }, - &Tusymlink{ - Tsymlink: Tsymlink{ - Directory: 1, - Name: "a", - Target: "b", - GID: 2, - }, - UID: 3, - }, - &Rusymlink{ - Rsymlink{QID: QID{Type: 1}}, - }, - &Tumknod{ - Tmknod: Tmknod{ - Directory: 1, - Name: "a", - Mode: 2, - Major: 3, - Minor: 4, - GID: 5, - }, - UID: 6, - }, - &Rumknod{ - Rmknod{QID: QID{Type: 1}}, - }, - } - - for _, enc := range objs { - // Encode the original. - data := make([]byte, initialBufferLength) - buf := buffer{data: data[:0]} - enc.Encode(&buf) - - // Create a new object, same as the first. - enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder) - buf2 := buffer{data: buf.data} - - // To be fair, we need to add any payloads (directly). - if pl, ok := enc.(payloader); ok { - enc2.(payloader).SetPayload(pl.Payload()) - } - - // And any file payloads (directly). - if fl, ok := enc.(filer); ok { - enc2.(filer).SetFilePayload(fl.FilePayload()) - } - - // Mark sure it was okay. - enc2.Decode(&buf2) - if buf2.isOverrun() { - t.Errorf("object %#v->%#v got overrun on decode", enc, enc2) - continue - } - - // Check that they are equal. - if !reflect.DeepEqual(enc, enc2) { - t.Errorf("object %#v and %#v differ", enc, enc2) - continue - } - } -} - -func TestMessageStrings(t *testing.T) { - for typ := range msgRegistry.factories { - entry := &msgRegistry.factories[typ] - if entry.create != nil { - name := fmt.Sprintf("%+v", typ) - t.Run(name, func(t *testing.T) { - defer func() { // Ensure no panic. - if r := recover(); r != nil { - t.Errorf("printing %s failed: %v", name, r) - } - }() - m := entry.create() - _ = fmt.Sprintf("%v", m) - err := ErrInvalidMsgType{MsgType(typ)} - _ = err.Error() - }) - } - } -} - -func TestRegisterDuplicate(t *testing.T) { - defer func() { - if r := recover(); r == nil { - // We expect a panic. - t.FailNow() - } - }() - - // Register a duplicate. - msgRegistry.register(MsgRlerror, func() message { return &Rlerror{} }) -} - -func TestMsgCache(t *testing.T) { - // Cache starts empty. - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Message can be created with an empty cache. - msg, err := msgRegistry.get(0, MsgRlerror) - if err != nil { - t.Errorf("msgRegistry.get(): %v", err) - } - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 0; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Check that message is added to the cache when returned. - msgRegistry.put(msg) - if got, want := len(msgRegistry.factories[MsgRlerror].cache), 1; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } - - // Check that returned message is reused. - if got, err := msgRegistry.get(0, MsgRlerror); err != nil { - t.Errorf("msgRegistry.get(): %v", err) - } else if msg != got { - t.Errorf("Message not reused, got: %d, want: %d", got, msg) - } - - // Check that cache doesn't grow beyond max size. - for i := 0; i < maxCacheSize+1; i++ { - msgRegistry.put(&Rlerror{}) - } - if got, want := len(msgRegistry.factories[MsgRlerror].cache), maxCacheSize; got != want { - t.Errorf("Wrong cache size, got: %d, want: %d", got, want) - } -} diff --git a/pkg/p9/p9_state_autogen.go b/pkg/p9/p9_state_autogen.go new file mode 100755 index 000000000..0b9556862 --- /dev/null +++ b/pkg/p9/p9_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package p9 + diff --git a/pkg/p9/p9_test.go b/pkg/p9/p9_test.go deleted file mode 100644 index 8dda6cc64..000000000 --- a/pkg/p9/p9_test.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "os" - "testing" -) - -func TestFileModeHelpers(t *testing.T) { - fns := map[FileMode]struct { - // name identifies the file mode. - name string - - // function is the function that should return true given the - // right FileMode. - function func(m FileMode) bool - }{ - ModeRegular: { - name: "regular", - function: FileMode.IsRegular, - }, - ModeDirectory: { - name: "directory", - function: FileMode.IsDir, - }, - ModeNamedPipe: { - name: "named pipe", - function: FileMode.IsNamedPipe, - }, - ModeCharacterDevice: { - name: "character device", - function: FileMode.IsCharacterDevice, - }, - ModeBlockDevice: { - name: "block device", - function: FileMode.IsBlockDevice, - }, - ModeSymlink: { - name: "symlink", - function: FileMode.IsSymlink, - }, - ModeSocket: { - name: "socket", - function: FileMode.IsSocket, - }, - } - for mode, info := range fns { - // Make sure the mode doesn't identify as anything but itself. - for testMode, testfns := range fns { - if mode != testMode && testfns.function(mode) { - t.Errorf("Mode %s returned true when asked if it was mode %s", info.name, testfns.name) - } - } - - // Make sure mode identifies as itself. - if !info.function(mode) { - t.Errorf("Mode %s returned false when asked if it was itself", info.name) - } - } -} - -func TestFileModeToQID(t *testing.T) { - for _, test := range []struct { - // name identifies the test. - name string - - // mode is the FileMode we start out with. - mode FileMode - - // want is the corresponding QIDType we expect. - want QIDType - }{ - { - name: "Directories are of type directory", - mode: ModeDirectory, - want: TypeDir, - }, - { - name: "Sockets are append-only files", - mode: ModeSocket, - want: TypeAppendOnly, - }, - { - name: "Named pipes are append-only files", - mode: ModeNamedPipe, - want: TypeAppendOnly, - }, - { - name: "Character devices are append-only files", - mode: ModeCharacterDevice, - want: TypeAppendOnly, - }, - { - name: "Symlinks are of type symlink", - mode: ModeSymlink, - want: TypeSymlink, - }, - { - name: "Regular files are of type regular", - mode: ModeRegular, - want: TypeRegular, - }, - { - name: "Block devices are regular files", - mode: ModeBlockDevice, - want: TypeRegular, - }, - } { - if qidType := test.mode.QIDType(); qidType != test.want { - t.Errorf("ModeToQID test %s failed: got %o, wanted %o", test.name, qidType, test.want) - } - } -} - -func TestP9ModeConverters(t *testing.T) { - for _, m := range []FileMode{ - ModeRegular, - ModeDirectory, - ModeCharacterDevice, - ModeBlockDevice, - ModeSocket, - ModeSymlink, - ModeNamedPipe, - } { - if mb := ModeFromOS(m.OSMode()); mb != m { - t.Errorf("Converting %o to OS.FileMode gives %o and is converted back as %o", m, m.OSMode(), mb) - } - } -} - -func TestOSModeConverters(t *testing.T) { - // Modes that can be converted back and forth. - for _, m := range []os.FileMode{ - 0, // Regular file. - os.ModeDir, - os.ModeCharDevice | os.ModeDevice, - os.ModeDevice, - os.ModeSocket, - os.ModeSymlink, - os.ModeNamedPipe, - } { - if mb := ModeFromOS(m).OSMode(); mb != m { - t.Errorf("Converting %o to p9.FileMode gives %o and is converted back as %o", m, ModeFromOS(m), mb) - } - } - - // Modes that will be converted to a regular file since p9 cannot - // express these. - for _, m := range []os.FileMode{ - os.ModeAppend, - os.ModeExclusive, - os.ModeTemporary, - } { - if p9Mode := ModeFromOS(m); p9Mode != ModeRegular { - t.Errorf("Converting %o to p9.FileMode should have given ModeRegular, but yielded %o", m, p9Mode) - } - } -} - -func TestAttrMaskContains(t *testing.T) { - req := AttrMask{Mode: true, Size: true} - have := AttrMask{} - if have.Contains(req) { - t.Fatalf("AttrMask %v should not be a superset of %v", have, req) - } - have.Mode = true - if have.Contains(req) { - t.Fatalf("AttrMask %v should not be a superset of %v", have, req) - } - have.Size = true - have.MTime = true - if !have.Contains(req) { - t.Fatalf("AttrMask %v should be a superset of %v", have, req) - } -} diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD deleted file mode 100644 index 6e939a49a..000000000 --- a/pkg/p9/p9test/BUILD +++ /dev/null @@ -1,88 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -alias( - name = "mockgen", - actual = "@com_github_golang_mock//mockgen:mockgen", -) - -MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9" - -# mockgen_reflect is a source file that contains mock generation code that -# imports the p9 package and generates a specification via reflection. The -# usual generation path must be split into two distinct parts because the full -# source tree is not available to all build targets. Only declared depencies -# are available (and even then, not the Go source files). -genrule( - name = "mockgen_reflect", - testonly = 1, - outs = ["mockgen_reflect.go"], - cmd = ( - "$(location :mockgen) " + - "-package p9test " + - "-prog_only " + MOCK_SRC_PACKAGE + " " + - "Attacher,File > $@" - ), - tools = [":mockgen"], -) - -# mockgen_exec is the binary that includes the above reflection generator. -# Running this binary will emit an encoded version of the p9 Attacher and File -# structures. This is consumed by the mocks genrule, below. -go_binary( - name = "mockgen_exec", - testonly = 1, - srcs = ["mockgen_reflect.go"], - deps = [ - "//pkg/p9", - "@com_github_golang_mock//mockgen/model:go_default_library", - ], -) - -# mocks consumes the encoded output above, and generates the full source for a -# set of mocks. These are included directly in the p9test library. -genrule( - name = "mocks", - testonly = 1, - outs = ["mocks.go"], - cmd = ( - "$(location :mockgen) " + - "-package p9test " + - "-exec_only $(location :mockgen_exec) " + MOCK_SRC_PACKAGE + " File > $@" - ), - tools = [ - ":mockgen", - ":mockgen_exec", - ], -) - -go_library( - name = "p9test", - srcs = [ - "mocks.go", - "p9test.go", - ], - importpath = "gvisor.dev/gvisor/pkg/p9/p9test", - visibility = ["//:sandbox"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/p9", - "//pkg/unet", - "@com_github_golang_mock//gomock:go_default_library", - ], -) - -go_test( - name = "client_test", - size = "small", - srcs = ["client_test.go"], - embed = [":p9test"], - deps = [ - "//pkg/fd", - "//pkg/p9", - "@com_github_golang_mock//gomock:go_default_library", - ], -) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go deleted file mode 100644 index 16799a790..000000000 --- a/pkg/p9/p9test/client_test.go +++ /dev/null @@ -1,2073 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9test - -import ( - "bytes" - "fmt" - "io" - "math/rand" - "os" - "reflect" - "strings" - "sync" - "syscall" - "testing" - "time" - - "github.com/golang/mock/gomock" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/p9" -) - -func TestPanic(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(nil)(nil) - defer d.Close() // Needed manually. - h.Attacher.EXPECT().Attach().Return(d, nil).Do(func() { - // Panic here, and ensure that we get back EFAULT. - panic("handler") - }) - - // Attach to the client. - if _, err := c.Attach("/"); err != syscall.EFAULT { - t.Fatalf("got attach err %v, want EFAULT", err) - } -} - -func TestAttachNoLeak(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(nil)(nil) - h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) - - // Attach to the client. - f, err := c.Attach("/") - if err != nil { - t.Fatalf("got attach err %v, want nil", err) - } - - // Don't close the file. This should be closed automatically when the - // client disconnects. The mock asserts that everything is closed - // exactly once. This statement just removes the unused variable error. - _ = f -} - -func TestBadAttach(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Return an error on attach. - h.Attacher.EXPECT().Attach().Return(nil, syscall.EINVAL).Times(1) - - // Attach to the client. - if _, err := c.Attach("/"); err != syscall.EINVAL { - t.Fatalf("got attach err %v, want syscall.EINVAL", err) - } -} - -func TestWalkAttach(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - // Create a new root. - d := h.NewDirectory(map[string]Generator{ - "a": h.NewDirectory(map[string]Generator{ - "b": h.NewFile(), - }), - })(nil) - h.Attacher.EXPECT().Attach().Return(d, nil).Times(1) - - // Attach to the client as a non-root, and ensure that the walk above - // occurs as expected. We should get back b, and all references should - // be dropped when the file is closed. - f, err := c.Attach("/a/b") - if err != nil { - t.Fatalf("got attach err %v, want nil", err) - } - defer f.Close() - - // Check that's a regular file. - if _, _, attr, err := f.GetAttr(p9.AttrMaskAll()); err != nil { - t.Errorf("got err %v, want nil", err) - } else if !attr.Mode.IsRegular() { - t.Errorf("got mode %v, want regular file", err) - } -} - -// newTypeMap returns a new type map dictionary. -func newTypeMap(h *Harness) map[string]Generator { - return map[string]Generator{ - "directory": h.NewDirectory(map[string]Generator{}), - "file": h.NewFile(), - "symlink": h.NewSymlink(), - "block-device": h.NewBlockDevice(), - "character-device": h.NewCharacterDevice(), - "named-pipe": h.NewNamedPipe(), - "socket": h.NewSocket(), - } -} - -// newRoot returns a new root filesystem. -// -// This is set up in a deterministic way for testing most operations. -// -// The represented file system looks like: -// - file -// - symlink -// - directory -// ... -// + one -// - file -// - symlink -// - directory -// ... -// + two -// - file -// - symlink -// - directory -// ... -// + three -// - file -// - symlink -// - directory -// ... -func newRoot(h *Harness, c *p9.Client) (*Mock, p9.File) { - root := newTypeMap(h) - one := newTypeMap(h) - two := newTypeMap(h) - three := newTypeMap(h) - one["two"] = h.NewDirectory(two) // Will be nested in one. - root["one"] = h.NewDirectory(one) // Top level. - root["three"] = h.NewDirectory(three) // Alternate top-level. - - // Create a new root. - rootBackend := h.NewDirectory(root)(nil) - h.Attacher.EXPECT().Attach().Return(rootBackend, nil) - - // Attach to the client. - r, err := c.Attach("/") - if err != nil { - h.t.Fatalf("got attach err %v, want nil", err) - } - - return rootBackend, r -} - -func allInvalidNames(from string) []string { - return []string{ - from + "/other", - from + "/..", - from + "/.", - from + "/", - "other/" + from, - "/" + from, - "./" + from, - "../" + from, - ".", - "..", - "/", - "", - } -} - -func TestWalkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Run relevant tests. - for name := range newTypeMap(h) { - // These are all the various ways that one might attempt to - // construct compound paths. They should all be rejected, as - // any compound that contains a / is not allowed, as well as - // the singular paths of '.' and '..'. - if _, _, err := root.Walk([]string{".", name}); err != syscall.EINVAL { - t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{"..", name}); err != syscall.EINVAL { - t.Errorf("Walk through . %s wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{name, "."}); err != syscall.EINVAL { - t.Errorf("Walk through %s . wanted EINVAL, got %v", name, err) - } - if _, _, err := root.Walk([]string{name, ".."}); err != syscall.EINVAL { - t.Errorf("Walk through %s .. wanted EINVAL, got %v", name, err) - } - for _, invalidName := range allInvalidNames(name) { - if _, _, err := root.Walk([]string{invalidName}); err != syscall.EINVAL { - t.Errorf("Walk through %s wanted EINVAL, got %v", invalidName, err) - } - } - wantErr := syscall.EINVAL - if name == "directory" { - // We can attempt a walk through a directory. However, - // we should never see a file named "other", so we - // expect this to return ENOENT. - wantErr = syscall.ENOENT - } - if _, _, err := root.Walk([]string{name, "other"}); err != wantErr { - t.Errorf("Walk through %s/other wanted %v, got %v", name, wantErr, err) - } - - // Do a successful walk. - _, f, err := root.Walk([]string{name}) - if err != nil { - t.Errorf("Walk to %s wanted nil, got %v", name, err) - } - defer f.Close() - local := h.Pop(f) - - // Check that the file matches. - _, localMask, localAttr, localErr := local.GetAttr(p9.AttrMaskAll()) - if _, mask, attr, err := f.GetAttr(p9.AttrMaskAll()); mask != localMask || attr != localAttr || err != localErr { - t.Errorf("GetAttr got (%v, %v, %v), wanted (%v, %v, %v)", - mask, attr, err, localMask, localAttr, localErr) - } - - // Ensure we can't walk backwards. - if _, _, err := f.Walk([]string{"."}); err != syscall.EINVAL { - t.Errorf("Walk through %s/. wanted EINVAL, got %v", name, err) - } - if _, _, err := f.Walk([]string{".."}); err != syscall.EINVAL { - t.Errorf("Walk through %s/.. wanted EINVAL, got %v", name, err) - } - } -} - -// fileGenerator is a function to generate files via walk or create. -// -// Examples are: -// - walkHelper -// - walkAndOpenHelper -// - createHelper -type fileGenerator func(*Harness, string, p9.File) (*Mock, *Mock, p9.File) - -// walkHelper walks to the given file. -// -// The backends of the parent and walked file are returned, as well as the -// walked client file. -func walkHelper(h *Harness, name string, dir p9.File) (parentBackend *Mock, walkedBackend *Mock, walked p9.File) { - _, parent, err := dir.Walk(nil) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer parent.Close() - parentBackend = h.Pop(parent) - - _, walked, err = parent.Walk([]string{name}) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - walkedBackend = h.Pop(walked) - - return parentBackend, walkedBackend, walked -} - -// walkAndOpenHelper additionally opens the walked file, if possible. -func walkAndOpenHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { - parentBackend, walkedBackend, walked := walkHelper(h, name, dir) - if p9.CanOpen(walkedBackend.Attr.Mode) { - // Open for all file types that we can. We stick to a read-only - // open here because directories may not be opened otherwise. - walkedBackend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := walked.Open(p9.ReadOnly); err != nil { - h.t.Errorf("got open err %v, want nil", err) - } - } else { - // ... or assert an error for others. - if _, _, _, err := walked.Open(p9.ReadOnly); err != syscall.EINVAL { - h.t.Errorf("got open err %v, want EINVAL", err) - } - } - return parentBackend, walkedBackend, walked -} - -// createHelper creates the given file and returns the parent directory, -// created file and client file, which must be closed when done. -func createHelper(h *Harness, name string, dir p9.File) (*Mock, *Mock, p9.File) { - // Clone the directory first, since Create replaces the existing file. - // We change the type after calling create. - _, dirThenFile, err := dir.Walk(nil) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - - // Create a new server-side file. On the server-side, the a new file is - // returned from a create call. The client will reuse the same file, - // but we still expect the normal chain of closes. This complicates - // things a bit because the "parent" will always chain to the cloned - // dir above. - dirBackend := h.Pop(dirThenFile) // New backend directory. - newFile := h.NewFile()(dirBackend) // New file with backend parent. - dirBackend.EXPECT().Create(name, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, newFile, newFile.QID, uint32(0), nil) - - // Create via the client. - _, dirThenFile, _, _, err = dirThenFile.Create(name, p9.ReadOnly, 0, 0, 0) - if err != nil { - h.t.Fatalf("got create err %v, want nil", err) - } - - // Ensure subsequent walks succeed. - dirBackend.AddChild(name, h.NewFile()) - return dirBackend, newFile, dirThenFile -} - -// deprecatedRemover allows us to access the deprecated Remove operation within -// the p9.File client object. -type deprecatedRemover interface { - Remove() error -} - -// checkDeleted asserts that relevant methods fail for an unlinked file. -// -// This function will close the file at the end. -func checkDeleted(h *Harness, file p9.File) { - defer file.Close() // See doc. - - if _, _, _, err := file.Open(p9.ReadOnly); err != syscall.EINVAL { - h.t.Errorf("open while deleted, got %v, want EINVAL", err) - } - if _, _, _, _, err := file.Create("created", p9.ReadOnly, 0, 0, 0); err != syscall.EINVAL { - h.t.Errorf("create while deleted, got %v, want EINVAL", err) - } - if _, err := file.Symlink("old", "new", 0, 0); err != syscall.EINVAL { - h.t.Errorf("symlink while deleted, got %v, want EINVAL", err) - } - // N.B. This link is technically invalid, but if a call to link is - // actually made in the backend then the mock will panic. - if err := file.Link(file, "new"); err != syscall.EINVAL { - h.t.Errorf("link while deleted, got %v, want EINVAL", err) - } - if err := file.RenameAt("src", file, "dst"); err != syscall.EINVAL { - h.t.Errorf("renameAt while deleted, got %v, want EINVAL", err) - } - if err := file.UnlinkAt("file", 0); err != syscall.EINVAL { - h.t.Errorf("unlinkAt while deleted, got %v, want EINVAL", err) - } - if err := file.Rename(file, "dst"); err != syscall.EINVAL { - h.t.Errorf("rename while deleted, got %v, want EINVAL", err) - } - if _, err := file.Readlink(); err != syscall.EINVAL { - h.t.Errorf("readlink while deleted, got %v, want EINVAL", err) - } - if _, err := file.Mkdir("dir", p9.ModeDirectory, 0, 0); err != syscall.EINVAL { - h.t.Errorf("mkdir while deleted, got %v, want EINVAL", err) - } - if _, err := file.Mknod("dir", p9.ModeDirectory, 0, 0, 0, 0); err != syscall.EINVAL { - h.t.Errorf("mknod while deleted, got %v, want EINVAL", err) - } - if _, err := file.Readdir(0, 1); err != syscall.EINVAL { - h.t.Errorf("readdir while deleted, got %v, want EINVAL", err) - } - if _, err := file.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL { - h.t.Errorf("connect while deleted, got %v, want EINVAL", err) - } - - // The remove method is technically deprecated, but we want to ensure - // that it still checks for deleted appropriately. We must first clone - // the file because remove is equivalent to close. - _, newFile, err := file.Walk(nil) - if err == syscall.EBUSY { - // We can't walk from here because this reference is open - // aleady. Okay, we will also have unopened cases through - // TestUnlink, just skip the remove operation for now. - return - } else if err != nil { - h.t.Fatalf("clone failed, got %v, want nil", err) - } - if err := newFile.(deprecatedRemover).Remove(); err != syscall.EINVAL { - h.t.Errorf("remove while deleted, got %v, want EINVAL", err) - } -} - -// deleter is a function to remove a file. -type deleter func(parent p9.File, name string) error - -// unlinkAt is a deleter. -func unlinkAt(parent p9.File, name string) error { - // Call unlink. Note that a filesystem may normally impose additional - // constaints on unlinkat success, such as ensuring that a directory is - // empty, requiring AT_REMOVEDIR in flags to remove a directory, etc. - // None of that is required internally (entire trees can be marked - // deleted when this operation succeeds), so the mock will succeed. - return parent.UnlinkAt(name, 0) -} - -// remove is a deleter. -func remove(parent p9.File, name string) error { - // See notes above re: remove. - _, newFile, err := parent.Walk([]string{name}) - if err != nil { - // Should not be expected. - return err - } - - // Do the actual remove. - if err := newFile.(deprecatedRemover).Remove(); err != nil { - return err - } - - // Ensure that the remove closed the file. - if err := newFile.(deprecatedRemover).Remove(); err != syscall.EBADF { - return syscall.EBADF // Propagate this code. - } - - return nil -} - -// unlinkHelper unlinks the noted path, and ensures that all relevant -// operations on that path, acquired from multiple paths, start failing. -func unlinkHelper(h *Harness, root p9.File, targetNames []string, targetGen fileGenerator, deleteFn deleter) { - // name is the file to be unlinked. - name := targetNames[len(targetNames)-1] - - // Walk to the directory containing the target. - _, parent, err := root.Walk(targetNames[:len(targetNames)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer parent.Close() - parentBackend := h.Pop(parent) - - // Walk to or generate the target file. - _, _, target := targetGen(h, name, parent) - defer checkDeleted(h, target) - - // Walk to a second reference. - _, second, err := parent.Walk([]string{name}) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer checkDeleted(h, second) - - // Walk to a third reference, from the start. - _, third, err := root.Walk(targetNames) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer checkDeleted(h, third) - - // This will be translated in the backend to an unlinkat. - parentBackend.EXPECT().UnlinkAt(name, uint32(0)).Return(nil) - - // Actually perform the deletion. - if err := deleteFn(parent, name); err != nil { - h.t.Fatalf("got delete err %v, want nil", err) - } -} - -func unlinkTest(t *testing.T, targetNames []string, targetGen fileGenerator) { - t.Run(fmt.Sprintf("unlinkAt(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - unlinkHelper(h, root, targetNames, targetGen, unlinkAt) - }) - t.Run(fmt.Sprintf("remove(%s)", strings.Join(targetNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - unlinkHelper(h, root, targetNames, targetGen, remove) - }) -} - -func TestUnlink(t *testing.T) { - // Unlink all files. - for name := range newTypeMap(nil) { - unlinkTest(t, []string{name}, walkHelper) - unlinkTest(t, []string{name}, walkAndOpenHelper) - unlinkTest(t, []string{"one", name}, walkHelper) - unlinkTest(t, []string{"one", name}, walkAndOpenHelper) - unlinkTest(t, []string{"one", "two", name}, walkHelper) - unlinkTest(t, []string{"one", "two", name}, walkAndOpenHelper) - } - - // Unlink a directory. - unlinkTest(t, []string{"one"}, walkHelper) - unlinkTest(t, []string{"one"}, walkAndOpenHelper) - unlinkTest(t, []string{"one", "two"}, walkHelper) - unlinkTest(t, []string{"one", "two"}, walkAndOpenHelper) - - // Unlink created files. - unlinkTest(t, []string{"created"}, createHelper) - unlinkTest(t, []string{"one", "created"}, createHelper) - unlinkTest(t, []string{"one", "two", "created"}, createHelper) -} - -func TestUnlinkAtInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.UnlinkAt(invalidName, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -// expectRenamed asserts an ordered sequence of rename calls, based on all the -// elements in elements being the source, and the first element therein -// changing to dstName, parented at dstParent. -func expectRenamed(file *Mock, elements []string, dstParent *Mock, dstName string) *gomock.Call { - if len(elements) > 0 { - // Recurse to the parent, if necessary. - call := expectRenamed(file.parent, elements[:len(elements)-1], dstParent, dstName) - - // Recursive case: this element is unchanged, but should have - // it's hook called after the parent. - return file.EXPECT().Renamed(file.parent, elements[len(elements)-1]).Do(func(p p9.File, _ string) { - file.parent = p.(*Mock) - }).After(call) - } - - // Base case: this is the changed element. - return file.EXPECT().Renamed(dstParent, dstName).Do(func(p p9.File, name string) { - file.parent = p.(*Mock) - }) -} - -// renamer is a rename function. -type renamer func(h *Harness, srcParent, dstParent p9.File, origName, newName string, selfRename bool) error - -// renameAt is a renamer. -func renameAt(_ *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { - return srcParent.RenameAt(srcName, dstParent, dstName) -} - -// rename is a renamer. -func rename(h *Harness, srcParent, dstParent p9.File, srcName, dstName string, selfRename bool) error { - _, f, err := srcParent.Walk([]string{srcName}) - if err != nil { - return err - } - defer f.Close() - if !selfRename { - backend := h.Pop(f) - backend.EXPECT().Renamed(gomock.Any(), dstName).Do(func(p p9.File, name string) { - backend.parent = p.(*Mock) // Required for close ordering. - }) - } - return f.Rename(dstParent, dstName) -} - -// renameHelper executes a rename, and asserts that all relevant elements -// receive expected notifications. If overwriting a file, this includes -// ensuring that the target has been appropriately marked as unlinked. -func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string, target fileGenerator, renameFn renamer) { - // Walk to the directory containing the target. - srcQID, targetParent, err := root.Walk(srcNames[:len(srcNames)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer targetParent.Close() - targetParentBackend := h.Pop(targetParent) - - // Walk to or generate the target file. - _, targetBackend, src := target(h, srcNames[len(srcNames)-1], targetParent) - defer src.Close() - - // Walk to a second reference. - _, second, err := targetParent.Walk([]string{srcNames[len(srcNames)-1]}) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer second.Close() - secondBackend := h.Pop(second) - - // Walk to a third reference, from the start. - _, third, err := root.Walk(srcNames) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer third.Close() - thirdBackend := h.Pop(third) - - // Find the common suffix to identify the rename parent. - var ( - renameDestPath []string - renameSrcPath []string - selfRename bool - ) - for i := 1; i <= len(srcNames) && i <= len(dstNames); i++ { - if srcNames[len(srcNames)-i] != dstNames[len(dstNames)-i] { - // Take the full prefix of dstNames up until this - // point, including the first mismatched name. The - // first mismatch must be the renamed entry. - renameDestPath = dstNames[:len(dstNames)-i+1] - renameSrcPath = srcNames[:len(srcNames)-i+1] - - // Does the renameDestPath fully contain the - // renameSrcPath here? If yes, then this is a mismatch. - // We can't rename the src to some subpath of itself. - if len(renameDestPath) > len(renameSrcPath) && - reflect.DeepEqual(renameDestPath[:len(renameSrcPath)], renameSrcPath) { - renameDestPath = nil - renameSrcPath = nil - continue - } - break - } - } - if len(renameSrcPath) == 0 || len(renameDestPath) == 0 { - // This must be a rename to self, or a tricky look-alike. This - // happens iff we fail to find a suitable divergence in the two - // paths. It's a true self move if the path length is the same. - renameDestPath = dstNames - renameSrcPath = srcNames - selfRename = len(srcNames) == len(dstNames) - } - - // Walk to the source parent. - _, srcParent, err := root.Walk(renameSrcPath[:len(renameSrcPath)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer srcParent.Close() - srcParentBackend := h.Pop(srcParent) - - // Walk to the destination parent. - _, dstParent, err := root.Walk(renameDestPath[:len(renameDestPath)-1]) - if err != nil { - h.t.Fatalf("got walk err %v, want nil", err) - } - defer dstParent.Close() - dstParentBackend := h.Pop(dstParent) - - // expectedErr is the result of the rename operation. - var expectedErr error - - // Walk to the target file, if one exists. - dstQID, dst, err := root.Walk(renameDestPath) - if err == nil { - if !selfRename && srcQID[0].Type == dstQID[0].Type { - // If there is a destination file, and is it of the - // same type as the source file, then we expect the - // rename to succeed. We expect the destination file to - // be deleted, so we run a deletion test on it in this - // case. - defer checkDeleted(h, dst) - } else { - if !selfRename { - // If the type is different than the - // destination, then we expect the rename to - // fail. We expect ensure that this is - // returned. - expectedErr = syscall.EINVAL - } else { - // This is the file being renamed to itself. - // This is technically allowed and a no-op, but - // all the triggers will fire. - } - dst.Close() - } - } - dstName := renameDestPath[len(renameDestPath)-1] // Renamed element. - srcName := renameSrcPath[len(renameSrcPath)-1] // Renamed element. - if expectedErr == nil && !selfRename { - // Expect all to be renamed appropriately. Note that if this is - // a final file being renamed, then we expect the file to be - // called with the new parent. If not, then we expect the - // rename hook to be called, but the parent will remain - // unchanged. - elements := srcNames[len(renameSrcPath):] - expectRenamed(targetBackend, elements, dstParentBackend, dstName) - expectRenamed(secondBackend, elements, dstParentBackend, dstName) - expectRenamed(thirdBackend, elements, dstParentBackend, dstName) - - // The target parent has also been opened, and may be moved - // directly or indirectly. - if len(elements) > 1 { - expectRenamed(targetParentBackend, elements[:len(elements)-1], dstParentBackend, dstName) - } - } - - // Expect the rename if it's not the same file. Note that like unlink, - // renames are always translated to the at variant in the backend. - if !selfRename { - srcParentBackend.EXPECT().RenameAt(srcName, dstParentBackend, dstName).Return(expectedErr) - } - - // Perform the actual rename; everything has been lined up. - if err := renameFn(h, srcParent, dstParent, srcName, dstName, selfRename); err != expectedErr { - h.t.Fatalf("got rename err %v, want %v", err, expectedErr) - } -} - -func renameTest(t *testing.T, srcNames []string, dstNames []string, target fileGenerator) { - t.Run(fmt.Sprintf("renameAt(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - renameHelper(h, root, srcNames, dstNames, target, renameAt) - }) - t.Run(fmt.Sprintf("rename(%s->%s)", strings.Join(srcNames, "/"), strings.Join(dstNames, "/")), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - renameHelper(h, root, srcNames, dstNames, target, rename) - }) -} - -func TestRename(t *testing.T) { - // In-directory rename, simple case. - for name := range newTypeMap(nil) { - // Within the root. - renameTest(t, []string{name}, []string{"renamed"}, walkHelper) - renameTest(t, []string{name}, []string{"renamed"}, walkAndOpenHelper) - - // Within a subdirectory. - renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "renamed"}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"created"}, []string{"renamed"}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"one", "renamed"}, createHelper) - - // Across directories. - for name := range newTypeMap(nil) { - // Down one level. - renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "two", "renamed"}, walkAndOpenHelper) - - // Up one level. - renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkHelper) - renameTest(t, []string{"one", "two", name}, []string{"one", "renamed"}, walkAndOpenHelper) - - // Across at the same level. - renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkHelper) - renameTest(t, []string{"one", name}, []string{"three", "renamed"}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"one", "created"}, []string{"one", "two", "renamed"}, createHelper) - renameTest(t, []string{"one", "two", "created"}, []string{"one", "renamed"}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"three", "renamed"}, createHelper) - - // Renaming parents. - for name := range newTypeMap(nil) { - // Rename a parent. - renameTest(t, []string{"one", name}, []string{"renamed", name}, walkHelper) - renameTest(t, []string{"one", name}, []string{"renamed", name}, walkAndOpenHelper) - - // Rename a super parent. - renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkHelper) - renameTest(t, []string{"one", "two", name}, []string{"renamed", name}, walkAndOpenHelper) - } - - // ... with created files. - renameTest(t, []string{"one", "created"}, []string{"renamed", "created"}, createHelper) - renameTest(t, []string{"one", "two", "created"}, []string{"renamed", "created"}, createHelper) - - // Over existing files, including itself. - for name := range newTypeMap(nil) { - for other := range newTypeMap(nil) { - // Overwrite the noted file (may be itself). - renameTest(t, []string{"one", name}, []string{"one", other}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", other}, walkAndOpenHelper) - - // Overwrite other files in another directory. - renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkHelper) - renameTest(t, []string{"one", name}, []string{"one", "two", other}, walkAndOpenHelper) - } - - // Overwrite by moving the parent. - renameTest(t, []string{"three", name}, []string{"one", name}, walkHelper) - renameTest(t, []string{"three", name}, []string{"one", name}, walkAndOpenHelper) - - // Create over the types. - renameTest(t, []string{"one", "created"}, []string{"one", name}, createHelper) - renameTest(t, []string{"one", "created"}, []string{"one", "two", name}, createHelper) - renameTest(t, []string{"three", "created"}, []string{"one", name}, createHelper) - } -} - -func TestRenameInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.Rename(root, invalidName); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestRenameAtInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.RenameAt(invalidName, root, "okay"); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - if err := root.RenameAt("okay", root, invalidName); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestReadlink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, f, err := root.Walk([]string{name}) - if err != nil { - t.Fatalf("walk failed: got %v, wanted nil", err) - } - defer f.Close() - backend := h.Pop(f) - - const symlinkTarget = "symlink-target" - - if backend.Attr.Mode.IsSymlink() { - // This should only go through on symlinks. - backend.EXPECT().Readlink().Return(symlinkTarget, nil) - } - - // Attempt a Readlink operation. - target, err := f.Readlink() - if err != nil && err != syscall.EINVAL { - t.Errorf("readlink got %v, wanted EINVAL", err) - } else if err == nil && target != symlinkTarget { - t.Errorf("readlink got %v, wanted %v", target, symlinkTarget) - } - }) - } -} - -// fdTest is a wrapper around operations that may send file descriptors. This -// asserts that the file descriptors are working as intended. -func fdTest(t *testing.T, sendFn func(*fd.FD) *fd.FD) { - // Create a pipe that we can read from. - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("unable to create pipe: %v", err) - } - defer r.Close() - defer w.Close() - - // Attempt to send the write end. - wFD, err := fd.NewFromFile(w) - if err != nil { - t.Fatalf("unable to convert file: %v", err) - } - defer wFD.Close() // This is a copy. - - // Send wFD and receive newFD. - newFD := sendFn(wFD) - defer newFD.Close() - - // Attempt to write. - const message = "hello" - if _, err := newFD.Write([]byte(message)); err != nil { - t.Fatalf("write got %v, wanted nil", err) - } - - // Should see the message on our end. - buffer := []byte(message) - if _, err := io.ReadFull(r, buffer); err != nil { - t.Fatalf("read got %v, wanted nil", err) - } - if string(buffer) != message { - t.Errorf("got message %v, wanted %v", string(buffer), message) - } -} - -func TestConnect(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Catch all the non-socket cases. - if !backend.Attr.Mode.IsSocket() { - // This has been set up to fail if Connect is called. - if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EINVAL { - t.Errorf("connect got %v, wanted EINVAL", err) - } - return - } - - // Ensure the fd exchange works. - fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Connect(p9.ConnectFlags(0)).Return(send, nil) - recv, err := backend.Connect(p9.ConnectFlags(0)) - if err != nil { - t.Fatalf("connect got %v, wanted nil", err) - } - return recv - }) - }) - } -} - -func TestReaddir(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Catch all the non-directory cases. - if !backend.Attr.Mode.IsDir() { - // This has also been set up to fail if Readdir is called. - if _, err := f.Readdir(0, 1); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - return - } - - // Ensure that readdir works for directories. - if _, err := f.Readdir(0, 1); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - backend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := f.Open(p9.ReadOnly); err != nil { - t.Errorf("readdir got %v, wanted nil", err) - } - backend.EXPECT().Readdir(uint64(0), uint32(1)).Times(1) - if _, err := f.Readdir(0, 1); err != nil { - t.Errorf("readdir got %v, wanted nil", err) - } - }) - } -} - -func TestOpen(t *testing.T) { - type openTest struct { - name string - mode p9.OpenFlags - err error - match func(p9.FileMode) bool - } - - cases := []openTest{ - { - name: "invalid", - mode: ^p9.OpenFlagsModeMask, - err: syscall.EINVAL, - match: func(p9.FileMode) bool { return true }, - }, - { - name: "not-openable-read-only", - mode: p9.ReadOnly, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "not-openable-write-only", - mode: p9.WriteOnly, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "not-openable-read-write", - mode: p9.ReadWrite, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, - }, - { - name: "directory-read-only", - mode: p9.ReadOnly, - err: nil, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "directory-read-write", - mode: p9.ReadWrite, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "directory-write-only", - mode: p9.WriteOnly, - err: syscall.EINVAL, - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - }, - { - name: "read-only", - mode: p9.ReadOnly, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - }, - { - name: "write-only", - mode: p9.WriteOnly, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - { - name: "read-write", - mode: p9.ReadWrite, - err: nil, - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, - }, - } - - // Open(mode OpenFlags) (*fd.FD, QID, uint32, error) - // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice - // - returning a file works as expected - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Does this match the case? - if !tc.match(backend.Attr.Mode) { - t.SkipNow() - } - - // Ensure open-required operations fail. - if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EINVAL { - t.Errorf("readAt got %v, wanted EINVAL", err) - } - if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EINVAL { - t.Errorf("writeAt got %v, wanted EINVAL", err) - } - if err := f.FSync(); err != syscall.EINVAL { - t.Errorf("fsync got %v, wanted EINVAL", err) - } - if _, err := f.Readdir(0, 1); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) - } - - // Attempt the given open. - if tc.err != nil { - // We expect an error, just test and return. - if _, _, _, err := f.Open(tc.mode); err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) - } - return - } - - // Run an FD test, since we expect success. - fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Open(tc.mode).Return(send, p9.QID{}, uint32(0), nil).Times(1) - recv, _, _, err := f.Open(tc.mode) - if err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) - } - return recv - }) - - // If the open was successful, attempt another one. - if _, _, _, err := f.Open(tc.mode); err != syscall.EINVAL { - t.Errorf("second open with mode %v got %v, want EINVAL", tc.mode, err) - } - - // Ensure that all illegal operations fail. - if _, _, err := f.Walk(nil); err != syscall.EINVAL && err != syscall.EBUSY { - t.Errorf("walk got %v, wanted EINVAL or EBUSY", err) - } - if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EINVAL && err != syscall.EBUSY { - t.Errorf("walkgetattr got %v, wanted EINVAL or EBUSY", err) - } - }) - } - } -} - -func TestClose(t *testing.T) { - type closeTest struct { - name string - closeFn func(backend *Mock, f p9.File) - } - - cases := []closeTest{ - { - name: "close", - closeFn: func(_ *Mock, f p9.File) { - f.Close() - }, - }, - { - name: "remove", - closeFn: func(backend *Mock, f p9.File) { - // Allow the rename call in the parent, automatically translated. - backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1) - f.(deprecatedRemover).Remove() - }, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s(%s)", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - - // Close via the prescribed method. - tc.closeFn(backend, f) - - // Everything should fail with EBADF. - if _, _, err := f.Walk(nil); err != syscall.EBADF { - t.Errorf("walk got %v, wanted EBADF", err) - } - if _, err := f.StatFS(); err != syscall.EBADF { - t.Errorf("statfs got %v, wanted EBADF", err) - } - if _, _, _, err := f.GetAttr(p9.AttrMaskAll()); err != syscall.EBADF { - t.Errorf("getattr got %v, wanted EBADF", err) - } - if err := f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}); err != syscall.EBADF { - t.Errorf("setattrk got %v, wanted EBADF", err) - } - if err := f.Rename(root, "new-name"); err != syscall.EBADF { - t.Errorf("rename got %v, wanted EBADF", err) - } - if err := f.Close(); err != syscall.EBADF { - t.Errorf("close got %v, wanted EBADF", err) - } - if _, _, _, err := f.Open(p9.ReadOnly); err != syscall.EBADF { - t.Errorf("open got %v, wanted EBADF", err) - } - if _, err := f.ReadAt([]byte("hello"), 0); err != syscall.EBADF { - t.Errorf("readAt got %v, wanted EBADF", err) - } - if _, err := f.WriteAt(make([]byte, 6), 0); err != syscall.EBADF { - t.Errorf("writeAt got %v, wanted EBADF", err) - } - if err := f.FSync(); err != syscall.EBADF { - t.Errorf("fsync got %v, wanted EBADF", err) - } - if _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 0, 0); err != syscall.EBADF { - t.Errorf("create got %v, wanted EBADF", err) - } - if _, err := f.Mkdir("new-directory", 0, 0, 0); err != syscall.EBADF { - t.Errorf("mkdir got %v, wanted EBADF", err) - } - if _, err := f.Symlink("old-name", "new-name", 0, 0); err != syscall.EBADF { - t.Errorf("symlink got %v, wanted EBADF", err) - } - if err := f.Link(root, "new-name"); err != syscall.EBADF { - t.Errorf("link got %v, wanted EBADF", err) - } - if _, err := f.Mknod("new-block-device", 0, 0, 0, 0, 0); err != syscall.EBADF { - t.Errorf("mknod got %v, wanted EBADF", err) - } - if err := f.RenameAt("old-name", root, "new-name"); err != syscall.EBADF { - t.Errorf("renameAt got %v, wanted EBADF", err) - } - if err := f.UnlinkAt("name", 0); err != syscall.EBADF { - t.Errorf("unlinkAt got %v, wanted EBADF", err) - } - if _, err := f.Readdir(0, 1); err != syscall.EBADF { - t.Errorf("readdir got %v, wanted EBADF", err) - } - if _, err := f.Readlink(); err != syscall.EBADF { - t.Errorf("readlink got %v, wanted EBADF", err) - } - if err := f.Flush(); err != syscall.EBADF { - t.Errorf("flush got %v, wanted EBADF", err) - } - if _, _, _, _, err := f.WalkGetAttr(nil); err != syscall.EBADF { - t.Errorf("walkgetattr got %v, wanted EBADF", err) - } - if _, err := f.Connect(p9.ConnectFlags(0)); err != syscall.EBADF { - t.Errorf("connect got %v, wanted EBADF", err) - } - }) - } - } -} - -// onlyWorksOnOpenThings is a helper test method for operations that should -// only work on files that have been explicitly opened. -func onlyWorksOnOpenThings(h *Harness, t *testing.T, name string, root p9.File, mode p9.OpenFlags, expectedErr error, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Does it work before opening? - if err := fn(backend, f, false); err != syscall.EINVAL { - t.Errorf("operation got %v, wanted EINVAL", err) - } - - // Is this openable? - if !p9.CanOpen(backend.Attr.Mode) { - return // Nothing to do. - } - - // If this is a directory, we can't handle writing. - if backend.Attr.Mode.IsDir() && (mode == p9.ReadWrite || mode == p9.WriteOnly) { - return // Skip. - } - - // Open the file. - backend.EXPECT().Open(mode) - if _, _, _, err := f.Open(mode); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - - // Attempt the operation. - if err := fn(backend, f, expectedErr == nil); err != expectedErr { - t.Fatalf("operation got %v, wanted %v", err, expectedErr) - } -} - -func TestRead(t *testing.T) { - type readTest struct { - name string - mode p9.OpenFlags - err error - } - - cases := []readTest{ - { - name: "read-only", - mode: p9.ReadOnly, - err: nil, - }, - { - name: "read-write", - mode: p9.ReadWrite, - err: nil, - }, - { - name: "write-only", - mode: p9.WriteOnly, - err: syscall.EPERM, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const message = "hello" - - onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, err := f.ReadAt([]byte(message), 0) - return err - } - - // Prepare for the call to readAt in the backend. - backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - copy(p, message) - }).Return(len(message), nil) - - // Make the client call. - p := make([]byte, 2*len(message)) // Double size. - n, err := f.ReadAt(p, 0) - - // Sanity check result. - if err != nil { - return err - } - if n != len(message) { - t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) - } - if !bytes.Equal(p[:n], []byte(message)) { - t.Fatalf("message incorrect, got %v, want %v", p, []byte(message)) - } - return nil // Success. - }) - }) - } - } -} - -func TestWrite(t *testing.T) { - type writeTest struct { - name string - mode p9.OpenFlags - err error - } - - cases := []writeTest{ - { - name: "read-only", - mode: p9.ReadOnly, - err: syscall.EPERM, - }, - { - name: "read-write", - mode: p9.ReadWrite, - err: nil, - }, - { - name: "write-only", - mode: p9.WriteOnly, - err: nil, - }, - } - - for name := range newTypeMap(nil) { - for _, tc := range cases { - t.Run(fmt.Sprintf("%s-%s", tc.name, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - const message = "hello" - - onlyWorksOnOpenThings(h, t, name, root, tc.mode, tc.err, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, err := f.WriteAt([]byte(message), 0) - return err - } - - // Prepare for the call to readAt in the backend. - var output []byte // Saved by Do below. - backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { - output = p - }).Return(len(message), nil) - - // Make the client call. - n, err := f.WriteAt([]byte(message), 0) - - // Sanity check result. - if err != nil { - return err - } - if n != len(message) { - t.Fatalf("message length incorrect, got %d, want %d", n, len(message)) - } - if !bytes.Equal(output, []byte(message)) { - t.Fatalf("message incorrect, got %v, want %v", output, []byte(message)) - } - return nil // Success. - }) - }) - } - } -} - -func TestFSync(t *testing.T) { - for name := range newTypeMap(nil) { - for _, mode := range []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} { - t.Run(fmt.Sprintf("%s-%s", mode, name), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnOpenThings(h, t, name, root, mode, nil, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().FSync().Times(1) - } - return f.FSync() - }) - }) - } - } -} - -func TestFlush(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - backend.EXPECT().Flush() - f.Flush() - }) - } -} - -// onlyWorksOnDirectories is a helper test method for operations that should -// only work on unopened directories, such as create, mkdir and symlink. -func onlyWorksOnDirectories(h *Harness, t *testing.T, name string, root p9.File, fn func(backend *Mock, f p9.File, shouldSucceed bool) error) { - // Walk to the file normally. - _, backend, f := walkHelper(h, name, root) - defer f.Close() - - // Only directories support mknod. - if !backend.Attr.Mode.IsDir() { - if err := fn(backend, f, false); err != syscall.EINVAL { - t.Errorf("operation got %v, wanted EINVAL", err) - } - return // Nothing else to do. - } - - // Should succeed. - if err := fn(backend, f, true); err != nil { - t.Fatalf("operation got %v, wanted nil", err) - } - - // Open the directory. - backend.EXPECT().Open(p9.ReadOnly).Times(1) - if _, _, _, err := f.Open(p9.ReadOnly); err != nil { - t.Fatalf("open got %v, wanted nil", err) - } - - // Should not work again. - if err := fn(backend, f, false); err != syscall.EINVAL { - t.Fatalf("operation got %v, wanted EINVAL", err) - } -} - -func TestCreate(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if !shouldSucceed { - _, _, _, _, err := f.Create("new-file", p9.ReadWrite, 0, 1, 2) - return err - } - - // If the create is going to succeed, then we - // need to create a new backend file, and we - // clone to ensure that we don't close the - // original. - _, newF, err := f.Walk(nil) - if err != nil { - t.Fatalf("clone got %v, wanted nil", err) - } - defer newF.Close() - newBackend := h.Pop(newF) - - // Run a regular FD test to validate that path. - fdTest(t, func(send *fd.FD) *fd.FD { - // Return the send FD on success. - newFile := h.NewFile()(backend) // New file with the parent backend. - newBackend.EXPECT().Create("new-file", p9.ReadWrite, p9.FileMode(0), p9.UID(1), p9.GID(2)).Return(send, newFile, p9.QID{}, uint32(0), nil) - - // Receive the fd back. - recv, _, _, _, err := newF.Create("new-file", p9.ReadWrite, 0, 1, 2) - if err != nil { - t.Fatalf("create got %v, wanted nil", err) - } - return recv - }) - - // The above will fail via normal test flow, so - // we can assume that it passed. - return nil - }) - }) - } -} - -func TestCreateInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if _, _, _, _, err := root.Create(invalidName, p9.ReadWrite, 0, 0, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestMkdir(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Mkdir("new-directory", p9.FileMode(0), p9.UID(1), p9.GID(2)) - } - _, err := f.Mkdir("new-directory", 0, 1, 2) - return err - }) - }) - } -} - -func TestMkdirInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if _, err := root.Mkdir(invalidName, 0, 0, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestSymlink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Symlink("old-name", "new-name", p9.UID(1), p9.GID(2)) - } - _, err := f.Symlink("old-name", "new-name", 1, 2) - return err - }) - }) - } -} - -func TestSyminkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - // We need only test for invalid names in the new name, - // the target can be an arbitrary string and we don't - // need to sanity check it. - if _, err := root.Symlink("old-name", invalidName, 0, 0); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestLink(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Link(gomock.Any(), "new-link") - } - return f.Link(f, "new-link") - }) - }) - } -} - -func TestLinkInvalid(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - for name := range newTypeMap(nil) { - for _, invalidName := range allInvalidNames(name) { - if err := root.Link(root, invalidName); err != syscall.EINVAL { - t.Errorf("got %v for name %q, want EINVAL", err, invalidName) - } - } - } -} - -func TestMknod(t *testing.T) { - for name := range newTypeMap(nil) { - t.Run(name, func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - onlyWorksOnDirectories(h, t, name, root, func(backend *Mock, f p9.File, shouldSucceed bool) error { - if shouldSucceed { - backend.EXPECT().Mknod("new-block-device", p9.FileMode(0), uint32(1), uint32(2), p9.UID(3), p9.GID(4)).Times(1) - } - _, err := f.Mknod("new-block-device", 0, 1, 2, 3, 4) - return err - }) - }) - } -} - -// concurrentFn is a specification of a concurrent operation. This is used to -// drive the concurrency tests below. -type concurrentFn struct { - name string - match func(p9.FileMode) bool - op func(h *Harness, backend *Mock, f p9.File, callback func()) -} - -func concurrentTest(t *testing.T, name string, fn1, fn2 concurrentFn, sameDir, expectedOkay bool) { - var ( - names1 []string - names2 []string - ) - if sameDir { - // Use the same file one directory up. - names1, names2 = []string{"one", name}, []string{"one", name} - } else { - // For different directories, just use siblings. - names1, names2 = []string{"one", name}, []string{"three", name} - } - - t.Run(fmt.Sprintf("%s(%v)+%s(%v)", fn1.name, names1, fn2.name, names2), func(t *testing.T) { - h, c := NewHarness(t) - defer h.Finish() - - _, root := newRoot(h, c) - defer root.Close() - - // Walk to both files as given. - _, f1, err := root.Walk(names1) - if err != nil { - t.Fatalf("error walking, got %v, want nil", err) - } - defer f1.Close() - b1 := h.Pop(f1) - _, f2, err := root.Walk(names2) - if err != nil { - t.Fatalf("error walking, got %v, want nil", err) - } - defer f2.Close() - b2 := h.Pop(f2) - - // Are these a good match for the current test case? - if !fn1.match(b1.Attr.Mode) { - t.SkipNow() - } - if !fn2.match(b2.Attr.Mode) { - t.SkipNow() - } - - // Construct our "concurrency creator". - in1 := make(chan struct{}, 1) - in2 := make(chan struct{}, 1) - var top sync.WaitGroup - var fns sync.WaitGroup - defer top.Wait() - top.Add(2) // Accounting for below. - defer fns.Done() - fns.Add(1) // See line above; released before top.Wait. - go func() { - defer top.Done() - fn1.op(h, b1, f1, func() { - in1 <- struct{}{} - fns.Wait() - }) - }() - go func() { - defer top.Done() - fn2.op(h, b2, f2, func() { - in2 <- struct{}{} - fns.Wait() - }) - }() - - // Compute a reasonable timeout. If we expect the operation to hang, - // give it 10 milliseconds before we assert that it's fine. After all, - // there will be a lot of these tests. If we don't expect it to hang, - // give it a full minute, since the machine could be slow. - timeout := 10 * time.Millisecond - if expectedOkay { - timeout = 1 * time.Minute - } - - // Read the first channel. - var second chan struct{} - select { - case <-in1: - second = in2 - case <-in2: - second = in1 - } - - // Catch concurrency. - select { - case <-second: - // We finished successful. Is this good? Depends on the - // expected result. - if !expectedOkay { - t.Errorf("%q and %q proceeded concurrently!", fn1.name, fn2.name) - } - case <-time.After(timeout): - // Great, things did not proceed concurrently. Is that what we - // expected? - if expectedOkay { - t.Errorf("%q and %q hung concurrently!", fn1.name, fn2.name) - } - } - }) -} - -func randomFileName() string { - return fmt.Sprintf("%x", rand.Int63()) -} - -func TestConcurrency(t *testing.T) { - readExclusive := []concurrentFn{ - { - // N.B. We can't explicitly check WalkGetAttr behavior, - // but we rely on the fact that the internal code paths - // are the same. - name: "walk", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // See the documentation of WalkCallback. - // Because walk is actually implemented by the - // mock, we need a special place for this - // callback. - // - // Note that a clone actually locks the parent - // node. So we walk from this node to test - // concurrent operations appropriately. - backend.WalkCallback = func() error { - callback() - return nil - } - f.Walk([]string{randomFileName()}) // Won't exist. - }, - }, - { - name: "fsync", - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()) - backend.EXPECT().FSync().Do(func() { - callback() - }) - f.Open(p9.ReadOnly) // Required. - f.FSync() - }, - }, - { - name: "readdir", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()) - backend.EXPECT().Readdir(gomock.Any(), gomock.Any()).Do(func(uint64, uint32) { - callback() - }) - f.Open(p9.ReadOnly) // Required. - f.Readdir(0, 1) - }, - }, - { - name: "readlink", - match: func(mode p9.FileMode) bool { return mode.IsSymlink() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Readlink().Do(func() { - callback() - }) - f.Readlink() - }, - }, - { - name: "connect", - match: func(mode p9.FileMode) bool { return mode.IsSocket() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Connect(gomock.Any()).Do(func(p9.ConnectFlags) { - callback() - }) - f.Connect(0) - }, - }, - { - name: "open", - match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Open(gomock.Any()).Do(func(p9.OpenFlags) { - callback() - }) - f.Open(p9.ReadOnly) - }, - }, - { - name: "flush", - match: func(mode p9.FileMode) bool { return true }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Flush().Do(func() { - callback() - }) - f.Flush() - }, - }, - } - writeExclusive := []concurrentFn{ - { - // N.B. We can't really check getattr. But this is an - // extremely low-risk function, it seems likely that - // this check is paranoid anyways. - name: "setattr", - match: func(mode p9.FileMode) bool { return true }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().SetAttr(gomock.Any(), gomock.Any()).Do(func(p9.SetAttrMask, p9.SetAttr) { - callback() - }) - f.SetAttr(p9.SetAttrMask{}, p9.SetAttr{}) - }, - }, - { - name: "unlinkAt", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { - callback() - }) - f.UnlinkAt(randomFileName(), 0) - }, - }, - { - name: "mknod", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Mknod(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, uint32, uint32, p9.UID, p9.GID) { - callback() - }) - f.Mknod(randomFileName(), 0, 0, 0, 0, 0) - }, - }, - { - name: "link", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Link(gomock.Any(), gomock.Any()).Do(func(p9.File, string) { - callback() - }) - f.Link(f, randomFileName()) - }, - }, - { - name: "symlink", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Symlink(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, string, p9.UID, p9.GID) { - callback() - }) - f.Symlink(randomFileName(), randomFileName(), 0, 0) - }, - }, - { - name: "mkdir", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().Mkdir(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.FileMode, p9.UID, p9.GID) { - callback() - }) - f.Mkdir(randomFileName(), 0, 0, 0) - }, - }, - { - name: "create", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Return an error for the creation operation, as this is the simplest. - backend.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil, p9.QID{}, uint32(0), syscall.EINVAL).Do(func(string, p9.OpenFlags, p9.FileMode, p9.UID, p9.GID) { - callback() - }) - f.Create(randomFileName(), p9.ReadOnly, 0, 0, 0) - }, - }, - } - globalExclusive := []concurrentFn{ - { - name: "remove", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Remove operates on a locked parent. So we - // add a child, walk to it and call remove. - // Note that because this operation can operate - // concurrently with itself, we need to - // generate a random file name. - randomFile := randomFileName() - backend.AddChild(randomFile, h.NewFile()) - defer backend.RemoveChild(randomFile) - _, file, err := f.Walk([]string{randomFile}) - if err != nil { - h.t.Fatalf("walk got %v, want nil", err) - } - - // Remove is automatically translated to the parent. - backend.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Do(func(string, uint32) { - callback() - }) - - // Remove is also a close. - file.(deprecatedRemover).Remove() - }, - }, - { - name: "rename", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - // Similarly to remove, because we need to - // operate on a child, we allow a walk. - randomFile := randomFileName() - backend.AddChild(randomFile, h.NewFile()) - defer backend.RemoveChild(randomFile) - _, file, err := f.Walk([]string{randomFile}) - if err != nil { - h.t.Fatalf("walk got %v, want nil", err) - } - defer file.Close() - fileBackend := h.Pop(file) - - // Rename is automatically translated to the parent. - backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { - callback() - }) - - // Attempt the rename. - fileBackend.EXPECT().Renamed(gomock.Any(), gomock.Any()) - file.Rename(f, randomFileName()) - }, - }, - { - name: "renameAt", - match: func(mode p9.FileMode) bool { return mode.IsDir() }, - op: func(h *Harness, backend *Mock, f p9.File, callback func()) { - backend.EXPECT().RenameAt(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(string, p9.File, string) { - callback() - }) - - // Attempt the rename. There are no active fids - // with this name, so we don't need to expect - // Renamed hooks on anything. - f.RenameAt(randomFileName(), f, randomFileName()) - }, - }, - } - - for _, fn1 := range readExclusive { - for _, fn2 := range readExclusive { - for name := range newTypeMap(nil) { - // Everything should be able to proceed in parallel. - concurrentTest(t, name, fn1, fn2, true, true) - concurrentTest(t, name, fn1, fn2, false, true) - } - } - } - - for _, fn1 := range append(readExclusive, writeExclusive...) { - for _, fn2 := range writeExclusive { - for name := range newTypeMap(nil) { - // Only cross-directory functions should proceed in parallel. - concurrentTest(t, name, fn1, fn2, true, false) - concurrentTest(t, name, fn1, fn2, false, true) - } - } - } - - for _, fn1 := range append(append(readExclusive, writeExclusive...), globalExclusive...) { - for _, fn2 := range globalExclusive { - for name := range newTypeMap(nil) { - // Nothing should be able to run in parallel. - concurrentTest(t, name, fn1, fn2, true, false) - concurrentTest(t, name, fn1, fn2, false, false) - } - } - } -} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go deleted file mode 100644 index 95846e5f7..000000000 --- a/pkg/p9/p9test/p9test.go +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package p9test provides standard mocks for p9. -package p9test - -import ( - "fmt" - "sync" - "sync/atomic" - "syscall" - "testing" - - "github.com/golang/mock/gomock" - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/unet" -) - -// Harness is an attacher mock. -type Harness struct { - t *testing.T - mockCtrl *gomock.Controller - Attacher *MockAttacher - wg sync.WaitGroup - clientSocket *unet.Socket - mu sync.Mutex - created []*Mock -} - -// globalPath is a QID.Path Generator. -var globalPath uint64 - -// MakePath returns a globally unique path. -func MakePath() uint64 { - return atomic.AddUint64(&globalPath, 1) -} - -// Generator is a function that generates a new file. -type Generator func(parent *Mock) *Mock - -// Mock is a common mock element. -type Mock struct { - p9.DefaultWalkGetAttr - *MockFile - parent *Mock - closed bool - harness *Harness - QID p9.QID - Attr p9.Attr - children map[string]Generator - - // WalkCallback is a special function that will be called from within - // the walk context. This is needed for the concurrent tests within - // this package. - WalkCallback func() error -} - -// globalMu protects the children maps in all mocks. Note that this is not a -// particularly elegant solution, but because the test has walks from the root -// through to final nodes, we must share maps below, and it's easiest to simply -// protect against concurrent access globally. -var globalMu sync.RWMutex - -// AddChild adds a new child to the Mock. -func (m *Mock) AddChild(name string, generator Generator) { - globalMu.Lock() - defer globalMu.Unlock() - m.children[name] = generator -} - -// RemoveChild removes the child with the given name. -func (m *Mock) RemoveChild(name string) { - globalMu.Lock() - defer globalMu.Unlock() - delete(m.children, name) -} - -// Matches implements gomock.Matcher.Matches. -func (m *Mock) Matches(x interface{}) bool { - if om, ok := x.(*Mock); ok { - return m.QID.Path == om.QID.Path - } - return false -} - -// String implements gomock.Matcher.String. -func (m *Mock) String() string { - return fmt.Sprintf("Mock{Mode: 0x%x, QID.Path: %d}", m.Attr.Mode, m.QID.Path) -} - -// GetAttr returns the current attributes. -func (m *Mock) GetAttr(mask p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { - return m.QID, p9.AttrMaskAll(), m.Attr, nil -} - -// Walk supports clone and walking in directories. -func (m *Mock) Walk(names []string) ([]p9.QID, p9.File, error) { - if m.WalkCallback != nil { - if err := m.WalkCallback(); err != nil { - return nil, nil, err - } - } - if len(names) == 0 { - // Clone the file appropriately. - nm := m.harness.NewMock(m.parent, m.QID.Path, m.Attr) - nm.children = m.children // Inherit children. - return []p9.QID{nm.QID}, nm, nil - } else if len(names) != 1 { - m.harness.t.Fail() // Should not happen. - return nil, nil, syscall.EINVAL - } - - if m.Attr.Mode.IsDir() { - globalMu.RLock() - defer globalMu.RUnlock() - if fn, ok := m.children[names[0]]; ok { - // Generate the child. - nm := fn(m) - return []p9.QID{nm.QID}, nm, nil - } - // No child found. - return nil, nil, syscall.ENOENT - } - - // Call the underlying mock. - return m.MockFile.Walk(names) -} - -// WalkGetAttr calls the default implementation; this is a client-side optimization. -func (m *Mock) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) { - return m.DefaultWalkGetAttr.WalkGetAttr(names) -} - -// Pop pops off the most recently created Mock and assert that this mock -// represents the same file passed in. If nil is passed in, no check is -// performed. -// -// Precondition: there must be at least one Mock or this will panic. -func (h *Harness) Pop(clientFile p9.File) *Mock { - h.mu.Lock() - defer h.mu.Unlock() - - if clientFile == nil { - // If no clientFile is provided, then we always return the last - // created file. The caller can safely use this as long as - // there is no concurrency. - m := h.created[len(h.created)-1] - h.created = h.created[:len(h.created)-1] - return m - } - - qid, _, _, err := clientFile.GetAttr(p9.AttrMaskAll()) - if err != nil { - // We do not expect this to happen. - panic(fmt.Sprintf("err during Pop: %v", err)) - } - - // Find the relevant file in our created list. We must scan the last - // from back to front to ensure that we favor the most recently - // generated file. - for i := len(h.created) - 1; i >= 0; i-- { - m := h.created[i] - if qid.Path == m.QID.Path { - // Copy and truncate. - copy(h.created[i:], h.created[i+1:]) - h.created = h.created[:len(h.created)-1] - return m - } - } - - // Unable to find relevant file. - panic(fmt.Sprintf("unable to locate file with QID %+v", qid.Path)) -} - -// NewMock returns a new base file. -func (h *Harness) NewMock(parent *Mock, path uint64, attr p9.Attr) *Mock { - m := &Mock{ - MockFile: NewMockFile(h.mockCtrl), - parent: parent, - harness: h, - QID: p9.QID{ - Type: p9.QIDType((attr.Mode & p9.FileModeMask) >> 12), - Path: path, - }, - Attr: attr, - } - - // Always ensure Close is after the parent's close. Note that this - // can't be done via a straight-forward After call, because the parent - // might change after initial creation. We ensure that this is true at - // close time. - m.EXPECT().Close().Return(nil).Times(1).Do(func() { - if m.parent != nil && m.parent.closed { - h.t.FailNow() - } - // Note that this should not be racy, as this operation should - // be protected by the Times(1) above first. - m.closed = true - }) - - // Remember what was created. - h.mu.Lock() - defer h.mu.Unlock() - h.created = append(h.created, m) - - return m -} - -// NewFile returns a new file mock. -// -// Note that ReadAt and WriteAt must be mocked separately. -func (h *Harness) NewFile() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeRegular}) - } -} - -// NewDirectory returns a new mock directory. -// -// Note that Mkdir, Link, Mknod, RenameAt, UnlinkAt and Readdir must be mocked -// separately. Walk is provided and children may be manipulated via AddChild -// and RemoveChild. After calling Walk remotely, one can use Pop to find the -// corresponding backend mock on the server side. -func (h *Harness) NewDirectory(contents map[string]Generator) Generator { - return func(parent *Mock) *Mock { - m := h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeDirectory}) - m.children = contents // Save contents. - return m - } -} - -// NewSymlink returns a new mock directory. -// -// Note that Readlink must be mocked separately. -func (h *Harness) NewSymlink() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSymlink}) - } -} - -// NewBlockDevice returns a new mock block device. -func (h *Harness) NewBlockDevice() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeBlockDevice}) - } -} - -// NewCharacterDevice returns a new mock character device. -func (h *Harness) NewCharacterDevice() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeCharacterDevice}) - } -} - -// NewNamedPipe returns a new mock named pipe. -func (h *Harness) NewNamedPipe() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeNamedPipe}) - } -} - -// NewSocket returns a new mock socket. -func (h *Harness) NewSocket() Generator { - return func(parent *Mock) *Mock { - return h.NewMock(parent, MakePath(), p9.Attr{Mode: p9.ModeSocket}) - } -} - -// Finish completes all checks and shuts down the server. -func (h *Harness) Finish() { - h.clientSocket.Close() - h.wg.Wait() - h.mockCtrl.Finish() -} - -// NewHarness creates and returns a new test server. -// -// It should always be used as: -// -// h, c := NewHarness(t) -// defer h.Finish() -// -func NewHarness(t *testing.T) (*Harness, *p9.Client) { - // Create the mock. - mockCtrl := gomock.NewController(t) - h := &Harness{ - t: t, - mockCtrl: mockCtrl, - Attacher: NewMockAttacher(mockCtrl), - } - - // Make socket pair. - serverSocket, clientSocket, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v wanted nil", err) - } - - // Start the server, synchronized on exit. - server := p9.NewServer(h.Attacher) - h.wg.Add(1) - go func() { - defer h.wg.Done() - server.Handle(serverSocket) - }() - - // Create the client. - client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString()) - if err != nil { - serverSocket.Close() - clientSocket.Close() - t.Fatalf("new client got %v, expected nil", err) - return nil, nil // Never hit. - } - - // Capture the client socket. - h.clientSocket = clientSocket - return h, client -} diff --git a/pkg/p9/pool_test.go b/pkg/p9/pool_test.go deleted file mode 100644 index e4746b8da..000000000 --- a/pkg/p9/pool_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" -) - -func TestPoolUnique(t *testing.T) { - p := pool{start: 1, limit: 3} - got := make(map[uint64]bool) - - for { - n, ok := p.Get() - if !ok { - break - } - - // Check unique. - if _, ok := got[n]; ok { - t.Errorf("pool spit out %v multiple times", n) - } - - // Record. - got[n] = true - } -} - -func TestExausted(t *testing.T) { - p := pool{start: 1, limit: 500} - for i := 0; i < 499; i++ { - _, ok := p.Get() - if !ok { - t.Fatalf("pool exhausted before 499 items") - } - } - - _, ok := p.Get() - if ok { - t.Errorf("pool not exhausted when it should be") - } -} - -func TestPoolRecycle(t *testing.T) { - p := pool{start: 1, limit: 500} - n1, _ := p.Get() - p.Put(n1) - n2, _ := p.Get() - if n1 != n2 { - t.Errorf("pool not recycling items") - } -} diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go deleted file mode 100644 index cdb3bc841..000000000 --- a/pkg/p9/transport_test.go +++ /dev/null @@ -1,229 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "io/ioutil" - "os" - "testing" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/unet" -) - -const ( - MsgTypeBadEncode = iota + 252 - MsgTypeBadDecode - MsgTypeUnregistered -) - -func TestSendRecv(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &Tlopen{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - t.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - t.Fatalf("got tag %v expected 1", tag) - } - if _, ok := m.(*Tlopen); !ok { - t.Fatalf("got message %v expected *Tlopen", m) - } -} - -// badDecode overruns on decode. -type badDecode struct{} - -func (*badDecode) Decode(b *buffer) { b.markOverrun() } -func (*badDecode) Encode(b *buffer) {} -func (*badDecode) Type() MsgType { return MsgTypeBadDecode } -func (*badDecode) String() string { return "badDecode{}" } - -func TestRecvOverrun(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &badDecode{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - if _, _, err := recv(server, maximumLength, msgRegistry.get); err == nil { - t.Fatalf("recv got err %v expected ErrSocket{ErrNoValidMessage}", err) - } -} - -// unregistered is not registered on decode. -type unregistered struct{} - -func (*unregistered) Decode(b *buffer) {} -func (*unregistered) Encode(b *buffer) {} -func (*unregistered) Type() MsgType { return MsgTypeUnregistered } -func (*unregistered) String() string { return "unregistered{}" } - -func TestRecvInvalidType(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - if err := send(client, Tag(1), &unregistered{}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - _, _, err = recv(server, maximumLength, msgRegistry.get) - if _, ok := err.(*ErrInvalidMsgType); !ok { - t.Fatalf("recv got err %v expected ErrInvalidMsgType", err) - } -} - -func TestSendRecvWithFile(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - // Create a tempfile. - osf, err := ioutil.TempFile("", "p9") - if err != nil { - t.Fatalf("tempfile got err %v expected nil", err) - } - os.Remove(osf.Name()) - f, err := fd.NewFromFile(osf) - osf.Close() - if err != nil { - t.Fatalf("unable to create file: %v", err) - } - - if err := send(client, Tag(1), &Rlopen{File: f}); err != nil { - t.Fatalf("send got err %v expected nil", err) - } - - // Enable withFile. - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - t.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - t.Fatalf("got tag %v expected 1", tag) - } - rlopen, ok := m.(*Rlopen) - if !ok { - t.Fatalf("got m %v expected *Rlopen", m) - } - if rlopen.File == nil { - t.Fatalf("got nil file expected non-nil") - } -} - -func TestRecvClosed(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - client.Close() - - _, _, err = recv(server, maximumLength, msgRegistry.get) - if err == nil { - t.Fatalf("got err nil expected non-nil") - } - if _, ok := err.(ErrSocket); !ok { - t.Fatalf("got err %v expected ErrSocket", err) - } -} - -func TestSendClosed(t *testing.T) { - server, client, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("socketpair got err %v expected nil", err) - } - server.Close() - defer client.Close() - - err = send(client, Tag(1), &Tlopen{}) - if err == nil { - t.Fatalf("send got err nil expected non-nil") - } - if _, ok := err.(ErrSocket); !ok { - t.Fatalf("got err %v expected ErrSocket", err) - } -} - -func BenchmarkSendRecv(b *testing.B) { - server, client, err := unet.SocketPair(false) - if err != nil { - b.Fatalf("socketpair got err %v expected nil", err) - } - defer server.Close() - defer client.Close() - - // Exchange Rflush messages since these contain no data and therefore incur - // no additional marshaling overhead. - go func() { - for i := 0; i < b.N; i++ { - tag, m, err := recv(server, maximumLength, msgRegistry.get) - if err != nil { - b.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(1) { - b.Fatalf("got tag %v expected 1", tag) - } - if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %T expected *Rflush", m) - } - if err := send(server, Tag(2), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) - } - } - }() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := send(client, Tag(1), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) - } - tag, m, err := recv(client, maximumLength, msgRegistry.get) - if err != nil { - b.Fatalf("recv got err %v expected nil", err) - } - if tag != Tag(2) { - b.Fatalf("got tag %v expected 2", tag) - } - if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %v expected *Rflush", m) - } - } -} - -func init() { - msgRegistry.register(MsgTypeBadDecode, func() message { return &badDecode{} }) -} diff --git a/pkg/p9/version_test.go b/pkg/p9/version_test.go deleted file mode 100644 index 291e8580e..000000000 --- a/pkg/p9/version_test.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package p9 - -import ( - "testing" -) - -func TestVersionNumberEquivalent(t *testing.T) { - for i := uint32(0); i < 1024; i++ { - str := versionString(i) - version, ok := parseVersion(str) - if !ok { - t.Errorf("#%d: parseVersion(%q) failed, want success", i, str) - continue - } - if i != version { - t.Errorf("#%d: got version %d, want %d", i, i, version) - } - } -} - -func TestVersionStringEquivalent(t *testing.T) { - // There is one case where the version is not equivalent on purpose, - // that is 9P2000.L.Google.0. It is not equivalent because versionString - // must always return the more generic 9P2000.L for legacy servers that - // check for it. See net/9p/client.c. - str := "9P2000.L.Google.0" - version, ok := parseVersion(str) - if !ok { - t.Errorf("parseVersion(%q) failed, want success", str) - } - if got := versionString(version); got != "9P2000.L" { - t.Errorf("versionString(%d) got %q, want %q", version, got, "9P2000.L") - } - - for _, test := range []struct { - versionString string - }{ - { - versionString: "9P2000.L", - }, - { - versionString: "9P2000.L.Google.1", - }, - { - versionString: "9P2000.L.Google.347823894", - }, - } { - version, ok := parseVersion(test.versionString) - if !ok { - t.Errorf("parseVersion(%q) failed, want success", test.versionString) - continue - } - if got := versionString(version); got != test.versionString { - t.Errorf("versionString(%d) got %q, want %q", version, got, test.versionString) - } - } -} - -func TestParseVersion(t *testing.T) { - for _, test := range []struct { - versionString string - expectSuccess bool - expectedVersion uint32 - }{ - { - versionString: "9P", - expectSuccess: false, - }, - { - versionString: "9P.L", - expectSuccess: false, - }, - { - versionString: "9P200.L", - expectSuccess: false, - }, - { - versionString: "9P2000", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.-1", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.", - expectSuccess: false, - }, - { - versionString: "9P2000.L.Google.3546343826724305832", - expectSuccess: false, - }, - { - versionString: "9P2001.L", - expectSuccess: false, - }, - { - versionString: "9P2000.L", - expectSuccess: true, - expectedVersion: 0, - }, - { - versionString: "9P2000.L.Google.0", - expectSuccess: true, - expectedVersion: 0, - }, - { - versionString: "9P2000.L.Google.1", - expectSuccess: true, - expectedVersion: 1, - }, - } { - version, ok := parseVersion(test.versionString) - if ok != test.expectSuccess { - t.Errorf("parseVersion(%q) got (_, %v), want (_, %v)", test.versionString, ok, test.expectSuccess) - continue - } - if !test.expectSuccess { - continue - } - if version != test.expectedVersion { - t.Errorf("parseVersion(%q) got (%d, _), want (%d, _)", test.versionString, version, test.expectedVersion) - } - } -} - -func BenchmarkParseVersion(b *testing.B) { - for n := 0; n < b.N; n++ { - parseVersion("9P2000.L.Google.1") - } -} diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD deleted file mode 100644 index 697e7a2f4..000000000 --- a/pkg/procid/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "procid", - srcs = [ - "procid.go", - "procid_amd64.s", - "procid_arm64.s", - ], - importpath = "gvisor.dev/gvisor/pkg/procid", - visibility = ["//visibility:public"], -) - -go_test( - name = "procid_test", - size = "small", - srcs = [ - "procid_test.go", - ], - embed = [":procid"], -) - -go_test( - name = "procid_net_test", - size = "small", - srcs = [ - "procid_net_test.go", - "procid_test.go", - ], - embed = [":procid"], -) diff --git a/pkg/procid/procid_net_test.go b/pkg/procid/procid_net_test.go deleted file mode 100644 index b628e2285..000000000 --- a/pkg/procid/procid_net_test.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package procid - -// This file is just to force the inclusion of the "net" package, which will -// make the test binary a cgo one. -import ( - _ "net" -) diff --git a/pkg/procid/procid_state_autogen.go b/pkg/procid/procid_state_autogen.go new file mode 100755 index 000000000..f27a7c510 --- /dev/null +++ b/pkg/procid/procid_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package procid + diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go deleted file mode 100644 index 88dd0b3ae..000000000 --- a/pkg/procid/procid_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package procid - -import ( - "os" - "runtime" - "sync" - "syscall" - "testing" -) - -// runOnMain is used to send functions to run on the main (initial) thread. -var runOnMain = make(chan func(), 10) - -func checkProcid(t *testing.T, start *sync.WaitGroup, done *sync.WaitGroup) { - defer done.Done() - - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - start.Done() - start.Wait() - - procID := Current() - tid := syscall.Gettid() - - if procID != uint64(tid) { - t.Logf("Bad procid: expected %v, got %v", tid, procID) - t.Fail() - } -} - -func TestProcidInitialized(t *testing.T) { - var start sync.WaitGroup - var done sync.WaitGroup - - count := 100 - start.Add(count + 1) - done.Add(count + 1) - - // Run the check on the main thread. - // - // When cgo is not included, the only case when procid isn't initialized - // is in the main (initial) thread, so we have to test this case - // specifically. - runOnMain <- func() { - checkProcid(t, &start, &done) - } - - // Run the check on a number of different threads. - for i := 0; i < count; i++ { - go checkProcid(t, &start, &done) - } - - done.Wait() -} - -func TestMain(m *testing.M) { - // Make sure we remain at the main (initial) thread. - runtime.LockOSThread() - - // Start running tests in a different goroutine. - go func() { - os.Exit(m.Run()) - }() - - // Execute any functions that have been sent for execution in the main - // thread. - for f := range runOnMain { - f() - } -} diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD deleted file mode 100644 index f4f2001f3..000000000 --- a/pkg/rand/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "rand", - srcs = [ - "rand.go", - "rand_linux.go", - ], - importpath = "gvisor.dev/gvisor/pkg/rand", - visibility = ["//:sandbox"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/rand/rand_state_autogen.go b/pkg/rand/rand_state_autogen.go new file mode 100755 index 000000000..e46e9ec7e --- /dev/null +++ b/pkg/rand/rand_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package rand + diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD deleted file mode 100644 index 0652fbbe6..000000000 --- a/pkg/refs/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "weak_ref_list", - out = "weak_ref_list.go", - package = "refs", - prefix = "weakRef", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*WeakRef", - "Linker": "*WeakRef", - }, -) - -go_library( - name = "refs", - srcs = [ - "refcounter.go", - "refcounter_state.go", - "weak_ref_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/refs", - visibility = ["//:sandbox"], -) - -go_test( - name = "refs_test", - size = "small", - srcs = ["refcounter_test.go"], - embed = [":refs"], -) diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go deleted file mode 100644 index ffd3d3f07..000000000 --- a/pkg/refs/refcounter_test.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package refs - -import ( - "reflect" - "sync" - "testing" -) - -type testCounter struct { - AtomicRefCount - - // mu protects the boolean below. - mu sync.Mutex - - // destroyed indicates whether this was destroyed. - destroyed bool -} - -func (t *testCounter) DecRef() { - t.AtomicRefCount.DecRefWithDestructor(t.destroy) -} - -func (t *testCounter) destroy() { - t.mu.Lock() - defer t.mu.Unlock() - t.destroyed = true -} - -func (t *testCounter) IsDestroyed() bool { - t.mu.Lock() - defer t.mu.Unlock() - return t.destroyed -} - -func newTestCounter() *testCounter { - return &testCounter{destroyed: false} -} - -func TestOneRef(t *testing.T) { - tc := newTestCounter() - tc.DecRef() - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestTwoRefs(t *testing.T) { - tc := newTestCounter() - tc.IncRef() - tc.DecRef() - tc.DecRef() - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestMultiRefs(t *testing.T) { - tc := newTestCounter() - tc.IncRef() - tc.DecRef() - - tc.IncRef() - tc.DecRef() - - tc.DecRef() - - if !tc.IsDestroyed() { - t.Errorf("object should have been destroyed") - } -} - -func TestWeakRef(t *testing.T) { - tc := newTestCounter() - w := NewWeakRef(tc, nil) - - // Try resolving. - if x := w.Get(); x == nil { - t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) - } else { - x.DecRef() - } - - // Try resolving again. - if x := w.Get(); x == nil { - t.Errorf("weak reference didn't resolve: expected %v, got nil", tc) - } else { - x.DecRef() - } - - // Shouldn't be destroyed yet. (Can't continue if this fails.) - if tc.IsDestroyed() { - t.Fatalf("original object destroyed earlier than expected") - } - - // Drop the original reference. - tc.DecRef() - - // Assert destroyed. - if !tc.IsDestroyed() { - t.Errorf("original object not destroyed as expected") - } - - // Shouldn't be anything. - if x := w.Get(); x != nil { - t.Errorf("weak reference resolved: expected nil, got %v", x) - } -} - -func TestWeakRefDrop(t *testing.T) { - tc := newTestCounter() - w := NewWeakRef(tc, nil) - w.Drop() - - // Just assert the list is empty. - if !tc.weakRefs.Empty() { - t.Errorf("weak reference not dropped") - } - - // Drop the original reference. - tc.DecRef() -} - -type testWeakRefUser struct { - weakRefGone func() -} - -func (u *testWeakRefUser) WeakRefGone() { - u.weakRefGone() -} - -func TestCallback(t *testing.T) { - called := false - tc := newTestCounter() - var w *WeakRef - w = NewWeakRef(tc, &testWeakRefUser{func() { - called = true - - // Check that the weak ref has been zapped. - rc := w.obj.Load().(RefCounter) - if v := reflect.ValueOf(rc); v != reflect.Zero(v.Type()) { - t.Fatalf("Callback called with non-nil ptr") - } - - // Check that we're not holding the mutex by acquiring and - // releasing it. - tc.mu.Lock() - tc.mu.Unlock() - }}) - - // Drop the original reference, this must trigger the callback. - tc.DecRef() - - if !called { - t.Fatalf("Callback not called") - } -} diff --git a/pkg/refs/refs_state_autogen.go b/pkg/refs/refs_state_autogen.go new file mode 100755 index 000000000..959b6603f --- /dev/null +++ b/pkg/refs/refs_state_autogen.go @@ -0,0 +1,77 @@ +// automatically generated by stateify. + +package refs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *WeakRef) beforeSave() {} +func (x *WeakRef) save(m state.Map) { + x.beforeSave() + var obj savedReference = x.saveObj() + m.SaveValue("obj", obj) + m.Save("user", &x.user) +} + +func (x *WeakRef) afterLoad() {} +func (x *WeakRef) load(m state.Map) { + m.Load("user", &x.user) + m.LoadValue("obj", new(savedReference), func(y interface{}) { x.loadObj(y.(savedReference)) }) +} + +func (x *AtomicRefCount) beforeSave() {} +func (x *AtomicRefCount) save(m state.Map) { + x.beforeSave() + m.Save("refCount", &x.refCount) +} + +func (x *AtomicRefCount) afterLoad() {} +func (x *AtomicRefCount) load(m state.Map) { + m.Load("refCount", &x.refCount) +} + +func (x *savedReference) beforeSave() {} +func (x *savedReference) save(m state.Map) { + x.beforeSave() + m.Save("obj", &x.obj) +} + +func (x *savedReference) afterLoad() {} +func (x *savedReference) load(m state.Map) { + m.Load("obj", &x.obj) +} + +func (x *weakRefList) beforeSave() {} +func (x *weakRefList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *weakRefList) afterLoad() {} +func (x *weakRefList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *weakRefEntry) beforeSave() {} +func (x *weakRefEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *weakRefEntry) afterLoad() {} +func (x *weakRefEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("refs.WeakRef", (*WeakRef)(nil), state.Fns{Save: (*WeakRef).save, Load: (*WeakRef).load}) + state.Register("refs.AtomicRefCount", (*AtomicRefCount)(nil), state.Fns{Save: (*AtomicRefCount).save, Load: (*AtomicRefCount).load}) + state.Register("refs.savedReference", (*savedReference)(nil), state.Fns{Save: (*savedReference).save, Load: (*savedReference).load}) + state.Register("refs.weakRefList", (*weakRefList)(nil), state.Fns{Save: (*weakRefList).save, Load: (*weakRefList).load}) + state.Register("refs.weakRefEntry", (*weakRefEntry)(nil), state.Fns{Save: (*weakRefEntry).save, Load: (*weakRefEntry).load}) +} diff --git a/pkg/refs/weak_ref_list.go b/pkg/refs/weak_ref_list.go new file mode 100755 index 000000000..df8e98bf5 --- /dev/null +++ b/pkg/refs/weak_ref_list.go @@ -0,0 +1,173 @@ +package refs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type weakRefElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (weakRefElementMapper) linkerFor(elem *WeakRef) *WeakRef { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type weakRefList struct { + head *WeakRef + tail *WeakRef +} + +// Reset resets list l to the empty state. +func (l *weakRefList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *weakRefList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *weakRefList) Front() *WeakRef { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *weakRefList) Back() *WeakRef { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *weakRefList) PushFront(e *WeakRef) { + weakRefElementMapper{}.linkerFor(e).SetNext(l.head) + weakRefElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + weakRefElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *weakRefList) PushBack(e *WeakRef) { + weakRefElementMapper{}.linkerFor(e).SetNext(nil) + weakRefElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + weakRefElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *weakRefList) PushBackList(m *weakRefList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + weakRefElementMapper{}.linkerFor(l.tail).SetNext(m.head) + weakRefElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *weakRefList) InsertAfter(b, e *WeakRef) { + a := weakRefElementMapper{}.linkerFor(b).Next() + weakRefElementMapper{}.linkerFor(e).SetNext(a) + weakRefElementMapper{}.linkerFor(e).SetPrev(b) + weakRefElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + weakRefElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *weakRefList) InsertBefore(a, e *WeakRef) { + b := weakRefElementMapper{}.linkerFor(a).Prev() + weakRefElementMapper{}.linkerFor(e).SetNext(a) + weakRefElementMapper{}.linkerFor(e).SetPrev(b) + weakRefElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + weakRefElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *weakRefList) Remove(e *WeakRef) { + prev := weakRefElementMapper{}.linkerFor(e).Prev() + next := weakRefElementMapper{}.linkerFor(e).Next() + + if prev != nil { + weakRefElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + weakRefElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type weakRefEntry struct { + next *WeakRef + prev *WeakRef +} + +// Next returns the entry that follows e in the list. +func (e *weakRefEntry) Next() *WeakRef { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *weakRefEntry) Prev() *WeakRef { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *weakRefEntry) SetNext(elem *WeakRef) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *weakRefEntry) SetPrev(elem *WeakRef) { + e.prev = elem +} diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD deleted file mode 100644 index d1024e49d..000000000 --- a/pkg/seccomp/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data") - -package(licenses = ["notice"]) - -go_binary( - name = "victim", - testonly = 1, - srcs = ["seccomp_test_victim.go"], - deps = [":seccomp"], -) - -go_embed_data( - name = "victim_data", - testonly = 1, - src = "victim", - package = "seccomp", - var = "victimData", -) - -go_library( - name = "seccomp", - srcs = [ - "seccomp.go", - "seccomp_amd64.go", - "seccomp_arm64.go", - "seccomp_rules.go", - "seccomp_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/seccomp", - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/bpf", - "//pkg/log", - ], -) - -go_test( - name = "seccomp_test", - size = "small", - srcs = [ - "seccomp_test.go", - ":victim_data", - ], - embed = [":seccomp"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/bpf", - ], -) diff --git a/pkg/seccomp/seccomp_state_autogen.go b/pkg/seccomp/seccomp_state_autogen.go new file mode 100755 index 000000000..0fc23d1a8 --- /dev/null +++ b/pkg/seccomp/seccomp_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package seccomp + diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go deleted file mode 100644 index 353686ed3..000000000 --- a/pkg/seccomp/seccomp_test.go +++ /dev/null @@ -1,505 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seccomp - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "math" - "math/rand" - "os" - "os/exec" - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/bpf" -) - -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - -// newVictim makes a victim binary. -func newVictim() (string, error) { - f, err := ioutil.TempFile("", "victim") - if err != nil { - return "", err - } - defer f.Close() - path := f.Name() - if _, err := io.Copy(f, bytes.NewBuffer(victimData)); err != nil { - os.Remove(path) - return "", err - } - if err := os.Chmod(path, 0755); err != nil { - os.Remove(path) - return "", err - } - return path, nil -} - -// asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} -} - -func TestBasic(t *testing.T) { - type spec struct { - // desc is the test's description. - desc string - - // data is the input data. - data seccompData - - // want is the expected return value of the BPF program. - want linux.BPFAction - } - - for _, test := range []struct { - ruleSets []RuleSet - defaultAction linux.BPFAction - specs []spec - }{ - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{1: {}}, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Single syscall allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Single syscall disallowed", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowValue(0x1), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - { - Rules: SyscallRules{ - 1: {}, - 2: {}, - }, - Action: linux.SECCOMP_RET_TRAP, - }, - }, - defaultAction: linux.SECCOMP_RET_KILL_THREAD, - specs: []spec{ - { - desc: "Multiple rulesets allowed (1a)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x1}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple rulesets allowed (1b)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple rulesets allowed (2)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_KILL_THREAD, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - 3: {}, - 5: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Multiple syscalls allowed (1)", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple syscalls allowed (3)", - data: seccompData{nr: 3, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple syscalls allowed (5)", - data: seccompData{nr: 5, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Multiple syscalls disallowed (0)", - data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (2)", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (4)", - data: seccompData{nr: 4, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (6)", - data: seccompData{nr: 6, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "Multiple syscalls disallowed (100)", - data: seccompData{nr: 100, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Wrong architecture", - data: seccompData{nr: 1, arch: 123}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: {}, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Syscall disallowed, action trap", - data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowAny{}, - AllowValue(0xf), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Syscall argument allowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xf}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Syscall argument disallowed", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xe}}, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowValue(0xf), - }, - { - AllowValue(0xe), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf}}, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "Syscall argument allowed, two rules", - data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xe}}, - want: linux.SECCOMP_RET_ALLOW, - }, - }, - }, - { - ruleSets: []RuleSet{ - { - Rules: SyscallRules{ - 1: []Rule{ - { - AllowValue(0), - AllowValue(math.MaxUint64 - 1), - AllowValue(math.MaxUint32), - }, - }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, - defaultAction: linux.SECCOMP_RET_TRAP, - specs: []spec{ - { - desc: "64bit syscall argument allowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, - }, - want: linux.SECCOMP_RET_ALLOW, - }, - { - desc: "64bit syscall argument disallowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - { - desc: "64bit syscall argument disallowed", - data: seccompData{ - nr: 1, - arch: linux.AUDIT_ARCH_X86_64, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, - }, - want: linux.SECCOMP_RET_TRAP, - }, - }, - }, - } { - instrs, err := BuildProgram(test.ruleSets, test.defaultAction) - if err != nil { - t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err) - continue - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err) - continue - } - for _, spec := range test.specs { - got, err := bpf.Exec(p, spec.data.asInput()) - if err != nil { - t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err) - continue - } - if got != uint32(spec.want) { - t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want) - } - } - } -} - -// TestRandom tests that randomly generated rules are encoded correctly. -func TestRandom(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - size := rand.Intn(50) + 1 - syscallRules := make(map[uintptr][]Rule) - for len(syscallRules) < size { - n := uintptr(rand.Intn(200)) - if _, ok := syscallRules[n]; !ok { - syscallRules[n] = []Rule{} - } - } - - fmt.Printf("Testing filters: %v", syscallRules) - instrs, err := BuildProgram([]RuleSet{ - RuleSet{ - Rules: syscallRules, - Action: linux.SECCOMP_RET_ALLOW, - }, - }, linux.SECCOMP_RET_TRAP) - if err != nil { - t.Fatalf("buildProgram() got error: %v", err) - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Fatalf("bpf.Compile() got error: %v", err) - } - for i := uint32(0); i < 200; i++ { - data := seccompData{nr: i, arch: linux.AUDIT_ARCH_X86_64} - got, err := bpf.Exec(p, data.asInput()) - if err != nil { - t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i) - continue - } - want := linux.SECCOMP_RET_TRAP - if _, ok := syscallRules[uintptr(i)]; ok { - want = linux.SECCOMP_RET_ALLOW - } - if got != uint32(want) { - t.Errorf("bpf.Exec() = %d, want: %d, for syscall %d", got, want, i) - } - } -} - -// TestReadDeal checks that a process dies when it trips over the filter and -// that it doesn't die when the filter is not triggered. -func TestRealDeal(t *testing.T) { - for _, test := range []struct { - die bool - want string - }{ - {die: true, want: "bad system call"}, - {die: false, want: "Syscall was allowed!!!"}, - } { - victim, err := newVictim() - if err != nil { - t.Fatalf("unable to get victim: %v", err) - } - defer os.Remove(victim) - dieFlag := fmt.Sprintf("-die=%v", test.die) - cmd := exec.Command(victim, dieFlag) - - out, err := cmd.CombinedOutput() - if test.die { - if err == nil { - t.Errorf("victim was not killed as expected, output: %s", out) - continue - } - // Depending on kernel version, either RET_TRAP or RET_KILL_PROCESS is - // used. RET_TRAP dumps reason for exit in output, while RET_KILL_PROCESS - // returns SIGSYS as exit status. - if !strings.Contains(string(out), test.want) && - !strings.Contains(err.Error(), test.want) { - t.Errorf("Victim error is wrong, got: %v, err: %v, want: %v", string(out), err, test.want) - continue - } - } else { - if err != nil { - t.Errorf("victim failed to execute, err: %v", err) - continue - } - if !strings.Contains(string(out), test.want) { - t.Errorf("Victim output is wrong, got: %v, want: %v", string(out), test.want) - continue - } - } - } -} - -// TestMerge ensures that empty rules are not erased when rules are merged. -func TestMerge(t *testing.T) { - for _, tst := range []struct { - name string - main []Rule - merge []Rule - want []Rule - }{ - { - name: "empty both", - main: nil, - merge: nil, - want: []Rule{{}, {}}, - }, - { - name: "empty main", - main: nil, - merge: []Rule{{}}, - want: []Rule{{}, {}}, - }, - { - name: "empty merge", - main: []Rule{{}}, - merge: nil, - want: []Rule{{}, {}}, - }, - } { - t.Run(tst.name, func(t *testing.T) { - mainRules := SyscallRules{1: tst.main} - mergeRules := SyscallRules{1: tst.merge} - mainRules.Merge(mergeRules) - if got, want := len(mainRules[1]), len(tst.want); got != want { - t.Errorf("wrong length, got: %d, want: %d", got, want) - } - for i, r := range mainRules[1] { - if r != tst.want[i] { - t.Errorf("result, got: %v, want: %v", r, tst.want[i]) - } - } - }) - } -} - -// TestAddRule ensures that empty rules are not erased when rules are added. -func TestAddRule(t *testing.T) { - rules := SyscallRules{1: {}} - rules.AddRule(1, Rule{}) - if got, want := len(rules[1]), 2; got != want { - t.Errorf("len(rules[1]), got: %d, want: %d", got, want) - } -} diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go deleted file mode 100644 index 62ae1fd9f..000000000 --- a/pkg/seccomp/seccomp_test_victim.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Test binary used to test that seccomp filters are properly constructed and -// indeed kill the process on violation. -package main - -import ( - "flag" - "fmt" - "os" - "syscall" - - "gvisor.dev/gvisor/pkg/seccomp" -) - -func main() { - dieFlag := flag.Bool("die", false, "trips over the filter if true") - flag.Parse() - - syscalls := seccomp.SyscallRules{ - syscall.SYS_ACCEPT: {}, - syscall.SYS_ARCH_PRCTL: {}, - syscall.SYS_BIND: {}, - syscall.SYS_BRK: {}, - syscall.SYS_CLOCK_GETTIME: {}, - syscall.SYS_CLONE: {}, - syscall.SYS_CLOSE: {}, - syscall.SYS_DUP: {}, - syscall.SYS_DUP2: {}, - syscall.SYS_EPOLL_CREATE1: {}, - syscall.SYS_EPOLL_CTL: {}, - syscall.SYS_EPOLL_WAIT: {}, - syscall.SYS_EPOLL_PWAIT: {}, - syscall.SYS_EXIT: {}, - syscall.SYS_EXIT_GROUP: {}, - syscall.SYS_FALLOCATE: {}, - syscall.SYS_FCHMOD: {}, - syscall.SYS_FCNTL: {}, - syscall.SYS_FSTAT: {}, - syscall.SYS_FSYNC: {}, - syscall.SYS_FTRUNCATE: {}, - syscall.SYS_FUTEX: {}, - syscall.SYS_GETDENTS64: {}, - syscall.SYS_GETPEERNAME: {}, - syscall.SYS_GETPID: {}, - syscall.SYS_GETSOCKNAME: {}, - syscall.SYS_GETSOCKOPT: {}, - syscall.SYS_GETTID: {}, - syscall.SYS_GETTIMEOFDAY: {}, - syscall.SYS_LISTEN: {}, - syscall.SYS_LSEEK: {}, - syscall.SYS_MADVISE: {}, - syscall.SYS_MINCORE: {}, - syscall.SYS_MMAP: {}, - syscall.SYS_MPROTECT: {}, - syscall.SYS_MUNLOCK: {}, - syscall.SYS_MUNMAP: {}, - syscall.SYS_NANOSLEEP: {}, - syscall.SYS_NEWFSTATAT: {}, - syscall.SYS_OPEN: {}, - syscall.SYS_POLL: {}, - syscall.SYS_PREAD64: {}, - syscall.SYS_PSELECT6: {}, - syscall.SYS_PWRITE64: {}, - syscall.SYS_READ: {}, - syscall.SYS_READLINKAT: {}, - syscall.SYS_READV: {}, - syscall.SYS_RECVMSG: {}, - syscall.SYS_RENAMEAT: {}, - syscall.SYS_RESTART_SYSCALL: {}, - syscall.SYS_RT_SIGACTION: {}, - syscall.SYS_RT_SIGPROCMASK: {}, - syscall.SYS_RT_SIGRETURN: {}, - syscall.SYS_SCHED_YIELD: {}, - syscall.SYS_SENDMSG: {}, - syscall.SYS_SETITIMER: {}, - syscall.SYS_SET_ROBUST_LIST: {}, - syscall.SYS_SETSOCKOPT: {}, - syscall.SYS_SHUTDOWN: {}, - syscall.SYS_SIGALTSTACK: {}, - syscall.SYS_SOCKET: {}, - syscall.SYS_SYNC_FILE_RANGE: {}, - syscall.SYS_TGKILL: {}, - syscall.SYS_UTIMENSAT: {}, - syscall.SYS_WRITE: {}, - syscall.SYS_WRITEV: {}, - } - die := *dieFlag - if !die { - syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{ - { - seccomp.AllowValue(10), - }, - } - } - - if err := seccomp.Install(syscalls); err != nil { - fmt.Printf("Failed to install seccomp: %v", err) - os.Exit(1) - } - fmt.Printf("Filters installed\n") - - syscall.RawSyscall(syscall.SYS_OPENAT, 10, 0, 0) - fmt.Printf("Syscall was allowed!!!\n") -} diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD deleted file mode 100644 index f38fb39f3..000000000 --- a/pkg/secio/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "secio", - srcs = [ - "full_reader.go", - "secio.go", - ], - importpath = "gvisor.dev/gvisor/pkg/secio", - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "secio_test", - size = "small", - srcs = ["secio_test.go"], - embed = [":secio"], -) diff --git a/pkg/secio/secio_state_autogen.go b/pkg/secio/secio_state_autogen.go new file mode 100755 index 000000000..ec559f264 --- /dev/null +++ b/pkg/secio/secio_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package secio + diff --git a/pkg/secio/secio_test.go b/pkg/secio/secio_test.go deleted file mode 100644 index d1d905187..000000000 --- a/pkg/secio/secio_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package secio - -import ( - "bytes" - "errors" - "io" - "io/ioutil" - "math" - "testing" -) - -var errEndOfBuffer = errors.New("write beyond end of buffer") - -// buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt. -// Reads beyond the end of the buffer return io.EOF. Writes beyond the end of -// the buffer return errEndOfBuffer. -type buffer struct { - Bytes []byte -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (b *buffer) ReadAt(dst []byte, off int64) (int, error) { - if off >= int64(len(b.Bytes)) { - return 0, io.EOF - } - n := copy(dst, b.Bytes[off:]) - if n < len(dst) { - return n, io.EOF - } - return n, nil -} - -// WriteAt implements io.WriterAt.WriteAt. -func (b *buffer) WriteAt(src []byte, off int64) (int, error) { - if off >= int64(len(b.Bytes)) { - return 0, errEndOfBuffer - } - n := copy(b.Bytes[off:], src) - if n < len(src) { - return n, errEndOfBuffer - } - return n, nil -} - -func newBufferString(s string) *buffer { - return &buffer{[]byte(s)} -} - -func TestOffsetReader(t *testing.T) { - buf := newBufferString("foobar") - r := NewOffsetReader(buf, 3) - dst, err := ioutil.ReadAll(r) - if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) - } -} - -func TestSectionReader(t *testing.T) { - buf := newBufferString("foobarbaz") - r := NewSectionReader(buf, 3, 3) - dst, err := ioutil.ReadAll(r) - if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr) - } -} - -func TestSectionReaderLimitOverflow(t *testing.T) { - // SectionReader behaves like OffsetReader when limit overflows int64. - buf := newBufferString("foobar") - r := NewSectionReader(buf, 3, math.MaxInt64) - dst, err := ioutil.ReadAll(r) - if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { - t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) - } -} - -func TestOffsetWriter(t *testing.T) { - buf := newBufferString("ABCDEF") - w := NewOffsetWriter(buf, 3) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} - -func TestSectionWriter(t *testing.T) { - buf := newBufferString("ABCDEFGHI") - w := NewSectionWriter(buf, 3, 3) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} - -func TestSectionWriterLimitOverflow(t *testing.T) { - // SectionWriter behaves like OffsetWriter when limit overflows int64. - buf := newBufferString("ABCDEF") - w := NewSectionWriter(buf, 3, math.MaxInt64) - n, err := w.Write([]byte("foobar")) - if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { - t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { - t.Errorf("buf.Bytes: got %q, wanted %q", got, want) - } -} diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD deleted file mode 100644 index 700385907..000000000 --- a/pkg/segment/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -load("//tools/go_generics:defs.bzl", "go_template") - -go_template( - name = "generic_range", - srcs = ["range.go"], - types = [ - "T", - ], -) - -go_template( - name = "generic_set", - srcs = [ - "set.go", - "set_state.go", - ], - opt_consts = [ - "minDegree", - ], - types = [ - "Key", - "Range", - "Value", - "Functions", - ], -) diff --git a/pkg/segment/set_state.go b/pkg/segment/set_state.go deleted file mode 100644 index 76de92591..000000000 --- a/pkg/segment/set_state.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -func (s *Set) saveRoot() *SegmentDataSlices { - return s.ExportSortedSlices() -} - -func (s *Set) loadRoot(sds *SegmentDataSlices) { - if err := s.ImportSortedSlices(sds); err != nil { - panic(err) - } -} diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD deleted file mode 100644 index 694486296..000000000 --- a/pkg/segment/test/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package( - default_visibility = ["//visibility:private"], - licenses = ["notice"], -) - -load("//tools/go_generics:defs.bzl", "go_template_instance") - -go_template_instance( - name = "int_range", - out = "int_range.go", - package = "segment", - template = "//pkg/segment:generic_range", - types = { - "T": "int", - }, -) - -go_template_instance( - name = "int_set", - out = "int_set.go", - package = "segment", - template = "//pkg/segment:generic_set", - types = { - "Key": "int", - "Range": "Range", - "Value": "int", - "Functions": "setFunctions", - }, -) - -go_library( - name = "segment", - testonly = 1, - srcs = [ - "int_range.go", - "int_set.go", - "set_functions.go", - ], - importpath = "gvisor.dev/gvisor/pkg/segment/segment", - deps = [ - "//pkg/state", - ], -) - -go_test( - name = "segment_test", - size = "small", - srcs = ["segment_test.go"], - embed = [":segment"], -) diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go deleted file mode 100644 index f19a005f3..000000000 --- a/pkg/segment/test/segment_test.go +++ /dev/null @@ -1,564 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -import ( - "fmt" - "math/rand" - "testing" -) - -const ( - // testSize is the baseline number of elements inserted into sets under - // test, and is chosen to be large enough to ensure interesting amounts of - // tree rebalancing. - // - // Note that because checkSet is called between each insertion/removal in - // some tests that use it, tests may be quadratic in testSize. - testSize = 8000 - - // valueOffset is the difference between the value and start of test - // segments. - valueOffset = 100000 -) - -func shuffle(xs []int) { - for i := range xs { - j := rand.Intn(i + 1) - xs[i], xs[j] = xs[j], xs[i] - } -} - -func randPermutation(size int) []int { - p := make([]int, size) - for i := range p { - p[i] = i - } - shuffle(p) - return p -} - -// checkSet returns an error if s is incorrectly sorted, does not contain -// exactly expectedSegments segments, or contains a segment for which val != -// key + valueOffset. -func checkSet(s *Set, expectedSegments int) error { - havePrev := false - prev := 0 - nrSegments := 0 - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - next := seg.Start() - if havePrev && prev >= next { - return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) - } - if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want) - } - prev = next - havePrev = true - nrSegments++ - } - if nrSegments != expectedSegments { - return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) - } - return nil -} - -// countSegmentsIn returns the number of segments in s. -func countSegmentsIn(s *Set) int { - var count int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ - } - return count -} - -func TestAddRandom(t *testing.T) { - var s Set - order := randPermutation(testSize) - var nrInsertions int - for i, j := range order { - if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { - t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) - break - } - nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Insertion order: %v", order[:nrInsertions]) - t.Logf("Set contents:\n%v", &s) - } -} - -func TestRemoveRandom(t *testing.T) { - var s Set - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - } - order := randPermutation(testSize) - var nrRemovals int - for i, j := range order { - seg := s.FindSegment(j) - if !seg.Ok() { - t.Errorf("Iteration %d: failed to find segment with key %d", i, j) - break - } - s.Remove(seg) - nrRemovals++ - if err := checkSet(&s, testSize-nrRemovals); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Removal order: %v", order[:nrRemovals]) - t.Logf("Set contents:\n%v", &s) - t.FailNow() - } -} - -func TestAddSequentialAdjacent(t *testing.T) { - var s Set - var nrInsertions int - for i := 0; i < testSize; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - } - - first := s.FirstSegment() - gotSeg, gotGap := first.PrevNonEmpty() - if wantGap := s.FirstGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("FirstSegment().PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap) - } - gotSeg, gotGap = first.NextNonEmpty() - if wantSeg := first.NextSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("FirstSegment().NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg) - } - - last := s.LastSegment() - gotSeg, gotGap = last.PrevNonEmpty() - if wantSeg := last.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("LastSegment().PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", gotSeg, gotGap, wantSeg) - } - gotSeg, gotGap = last.NextNonEmpty() - if wantGap := s.LastGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("LastSegment().NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", gotSeg, gotGap, wantGap) - } - - for seg := first.NextSegment(); seg != last; seg = seg.NextSegment() { - gotSeg, gotGap = seg.PrevNonEmpty() - if wantSeg := seg.PrevSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg) - } - gotSeg, gotGap = seg.NextNonEmpty() - if wantSeg := seg.NextSegment(); gotSeg != wantSeg || gotGap.Ok() { - t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (%v, <terminal iterator>)", seg, gotSeg, gotGap, wantSeg) - } - } -} - -func TestAddSequentialNonAdjacent(t *testing.T) { - var s Set - var nrInsertions int - for i := 0; i < testSize; i++ { - // The range here differs from TestAddSequentialAdjacent so that - // consecutive segments are not adjacent. - if !s.AddWithoutMerging(Range{2 * i, 2*i + 1}, 2*i+valueOffset) { - t.Fatalf("Failed to insert segment %d", i) - } - nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { - t.Errorf("Iteration %d: %v", i, err) - break - } - } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { - t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) - } - if t.Failed() { - t.Logf("Set contents:\n%v", &s) - } - - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - gotSeg, gotGap := seg.PrevNonEmpty() - if wantGap := seg.PrevGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("%v.PrevNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap) - } - gotSeg, gotGap = seg.NextNonEmpty() - if wantGap := seg.NextGap(); gotSeg.Ok() || gotGap != wantGap { - t.Errorf("%v.NextNonEmpty(): got (%v, %v), wanted (<terminal iterator>, %v)", seg, gotSeg, gotGap, wantGap) - } - } -} - -func TestMergeSplit(t *testing.T) { - tests := []struct { - name string - initial []Range - split bool - splitAddr int - final []Range - }{ - { - name: "Add merges after existing segment", - initial: []Range{{1000, 1100}, {1100, 1200}}, - final: []Range{{1000, 1200}}, - }, - { - name: "Add merges before existing segment", - initial: []Range{{1100, 1200}, {1000, 1100}}, - final: []Range{{1000, 1200}}, - }, - { - name: "Add merges between existing segments", - initial: []Range{{1000, 1100}, {1200, 1300}, {1100, 1200}}, - final: []Range{{1000, 1300}}, - }, - { - name: "SplitAt does nothing at a free address", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 300, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt does nothing at the beginning of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 100, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt does nothing at the end of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 200, - final: []Range{{100, 200}}, - }, - { - name: "SplitAt splits in the middle of a segment", - initial: []Range{{100, 200}}, - split: true, - splitAddr: 150, - final: []Range{{100, 150}, {150, 200}}, - }, - } -Tests: - for _, test := range tests { - var s Set - for _, r := range test.initial { - if !s.Add(r, 0) { - t.Errorf("%s: Add(%v) failed; set contents:\n%v", test.name, r, &s) - continue Tests - } - } - if test.split { - s.SplitAt(test.splitAddr) - } - var i int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) - continue Tests - } - if got, want := seg.Range(), test.final[i]; got != want { - t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s) - continue Tests - } - i++ - } - if i < len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s) - } - } -} - -func TestIsolate(t *testing.T) { - tests := []struct { - name string - initial Range - bounds Range - final []Range - }{ - { - name: "Isolate does not split a segment that falls inside bounds", - initial: Range{100, 200}, - bounds: Range{100, 200}, - final: []Range{{100, 200}}, - }, - { - name: "Isolate splits at beginning of segment", - initial: Range{50, 200}, - bounds: Range{100, 200}, - final: []Range{{50, 100}, {100, 200}}, - }, - { - name: "Isolate splits at end of segment", - initial: Range{100, 250}, - bounds: Range{100, 200}, - final: []Range{{100, 200}, {200, 250}}, - }, - { - name: "Isolate splits at beginning and end of segment", - initial: Range{50, 250}, - bounds: Range{100, 200}, - final: []Range{{50, 100}, {100, 200}, {200, 250}}, - }, - } -Tests: - for _, test := range tests { - var s Set - seg := s.Insert(s.FirstGap(), test.initial, 0) - seg = s.Isolate(seg, test.bounds) - if !test.bounds.IsSupersetOf(seg.Range()) { - t.Errorf("%s: Isolated segment %v lies outside bounds %v; set contents:\n%v", test.name, seg.Range(), test.bounds, &s) - } - var i int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) - continue Tests - } - if got, want := seg.Range(), test.final[i]; got != want { - t.Errorf("%s: Segment %d mismatch: got %v, wanted %v; set contents:\n%v", test.name, i, got, want, &s) - continue Tests - } - i++ - } - if i < len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, i, len(test.final), &s) - } - } -} - -func benchmarkAddSequential(b *testing.B, size int) { - for n := 0; n < b.N; n++ { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - } -} - -func benchmarkAddRandom(b *testing.B, size int) { - order := randPermutation(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - var s Set - for _, i := range order { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - } -} - -func benchmarkFindSequential(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - - b.ResetTimer() - for n := 0; n < b.N; n++ { - for i := 0; i < size; i++ { - if seg := s.FindSegment(i); !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - } - } -} - -func benchmarkFindRandom(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - order := randPermutation(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - for _, i := range order { - if si := s.FindSegment(i); !si.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - } - } -} - -func benchmarkIteration(b *testing.B, size int) { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - - b.ResetTimer() - var count uint64 - for n := 0; n < b.N; n++ { - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ - } - } - if got, want := count, uint64(size)*uint64(b.N); got != want { - b.Fatalf("Iterated wrong number of segments: got %d, wanted %d", got, want) - } -} - -func benchmarkAddFindRemoveSequential(b *testing.B, size int) { - for n := 0; n < b.N; n++ { - var s Set - for i := 0; i < size; i++ { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - for i := 0; i < size; i++ { - seg := s.FindSegment(i) - if !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - s.Remove(seg) - } - if !s.IsEmpty() { - b.Fatalf("Set not empty after all removals:\n%v", &s) - } - } -} - -func benchmarkAddFindRemoveRandom(b *testing.B, size int) { - order := randPermutation(size) - - b.ResetTimer() - for n := 0; n < b.N; n++ { - var s Set - for _, i := range order { - if !s.AddWithoutMerging(Range{i, i + 1}, i) { - b.Fatalf("Failed to insert segment %d", i) - } - } - for _, i := range order { - seg := s.FindSegment(i) - if !seg.Ok() { - b.Fatalf("Failed to find segment %d", i) - } - s.Remove(seg) - } - if !s.IsEmpty() { - b.Fatalf("Set not empty after all removals:\n%v", &s) - } - } -} - -// Although we don't generally expect our segment sets to get this big, they're -// useful for emulating the effect of cache pressure. -var testSizes = []struct { - desc string - size int -}{ - {"64", 1 << 6}, - {"256", 1 << 8}, - {"1K", 1 << 10}, - {"4K", 1 << 12}, - {"16K", 1 << 14}, - {"64K", 1 << 16}, -} - -func BenchmarkAddSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddSequential(b, test.size) - }) - } -} - -func BenchmarkAddRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddRandom(b, test.size) - }) - } -} - -func BenchmarkFindSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkFindSequential(b, test.size) - }) - } -} - -func BenchmarkFindRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkFindRandom(b, test.size) - }) - } -} - -func BenchmarkIteration(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkIteration(b, test.size) - }) - } -} - -func BenchmarkAddFindRemoveSequential(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddFindRemoveSequential(b, test.size) - }) - } -} - -func BenchmarkAddFindRemoveRandom(b *testing.B) { - for _, test := range testSizes { - b.Run(test.desc, func(b *testing.B) { - benchmarkAddFindRemoveRandom(b, test.size) - }) - } -} diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go deleted file mode 100644 index bcddb39bb..000000000 --- a/pkg/segment/test/set_functions.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -// Basic numeric constants that we define because the math package doesn't. -// TODO(nlacasse): These should be Math.MaxInt64/MinInt64? -const ( - maxInt = int(^uint(0) >> 1) - minInt = -maxInt - 1 -) - -type setFunctions struct{} - -func (setFunctions) MinKey() int { - return minInt -} - -func (setFunctions) MaxKey() int { - return maxInt -} - -func (setFunctions) ClearValue(*int) {} - -func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) { - return val1, true -} - -func (setFunctions) Split(_ Range, val int, _ int) (int, int) { - return val, val -} diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD deleted file mode 100644 index 53989301f..000000000 --- a/pkg/sentry/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -# This BUILD file defines a package_group that allows for interdependencies for -# sentry-internal packages. - -package(licenses = ["notice"]) - -package_group( - name = "internal", - packages = [ - "//pkg/sentry/...", - "//runsc/...", - ], -) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD deleted file mode 100644 index 7aace2d7b..000000000 --- a/pkg/sentry/arch/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "arch", - srcs = [ - "aligned.go", - "arch.go", - "arch_amd64.go", - "arch_amd64.s", - "arch_state_x86.go", - "arch_x86.go", - "auxv.go", - "signal_act.go", - "signal_amd64.go", - "signal_info.go", - "signal_stack.go", - "stack.go", - "syscalls_amd64.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/arch", - visibility = ["//:sandbox"], - deps = [ - ":registers_go_proto", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/cpuid", - "//pkg/log", - "//pkg/sentry/context", - "//pkg/sentry/limits", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) - -proto_library( - name = "registers_proto", - srcs = ["registers.proto"], - visibility = ["//visibility:public"], -) - -go_proto_library( - name = "registers_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", - proto = ":registers_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/arch/arch_state_autogen.go b/pkg/sentry/arch/arch_state_autogen.go new file mode 100755 index 000000000..9b1205531 --- /dev/null +++ b/pkg/sentry/arch/arch_state_autogen.go @@ -0,0 +1,193 @@ +// automatically generated by stateify. + +package arch + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *MmapLayout) beforeSave() {} +func (x *MmapLayout) save(m state.Map) { + x.beforeSave() + m.Save("MinAddr", &x.MinAddr) + m.Save("MaxAddr", &x.MaxAddr) + m.Save("BottomUpBase", &x.BottomUpBase) + m.Save("TopDownBase", &x.TopDownBase) + m.Save("DefaultDirection", &x.DefaultDirection) + m.Save("MaxStackRand", &x.MaxStackRand) +} + +func (x *MmapLayout) afterLoad() {} +func (x *MmapLayout) load(m state.Map) { + m.Load("MinAddr", &x.MinAddr) + m.Load("MaxAddr", &x.MaxAddr) + m.Load("BottomUpBase", &x.BottomUpBase) + m.Load("TopDownBase", &x.TopDownBase) + m.Load("DefaultDirection", &x.DefaultDirection) + m.Load("MaxStackRand", &x.MaxStackRand) +} + +func (x *context64) beforeSave() {} +func (x *context64) save(m state.Map) { + x.beforeSave() + m.Save("State", &x.State) + m.Save("sigFPState", &x.sigFPState) +} + +func (x *context64) afterLoad() {} +func (x *context64) load(m state.Map) { + m.Load("State", &x.State) + m.Load("sigFPState", &x.sigFPState) +} + +func (x *syscallPtraceRegs) beforeSave() {} +func (x *syscallPtraceRegs) save(m state.Map) { + x.beforeSave() + m.Save("R15", &x.R15) + m.Save("R14", &x.R14) + m.Save("R13", &x.R13) + m.Save("R12", &x.R12) + m.Save("Rbp", &x.Rbp) + m.Save("Rbx", &x.Rbx) + m.Save("R11", &x.R11) + m.Save("R10", &x.R10) + m.Save("R9", &x.R9) + m.Save("R8", &x.R8) + m.Save("Rax", &x.Rax) + m.Save("Rcx", &x.Rcx) + m.Save("Rdx", &x.Rdx) + m.Save("Rsi", &x.Rsi) + m.Save("Rdi", &x.Rdi) + m.Save("Orig_rax", &x.Orig_rax) + m.Save("Rip", &x.Rip) + m.Save("Cs", &x.Cs) + m.Save("Eflags", &x.Eflags) + m.Save("Rsp", &x.Rsp) + m.Save("Ss", &x.Ss) + m.Save("Fs_base", &x.Fs_base) + m.Save("Gs_base", &x.Gs_base) + m.Save("Ds", &x.Ds) + m.Save("Es", &x.Es) + m.Save("Fs", &x.Fs) + m.Save("Gs", &x.Gs) +} + +func (x *syscallPtraceRegs) afterLoad() {} +func (x *syscallPtraceRegs) load(m state.Map) { + m.Load("R15", &x.R15) + m.Load("R14", &x.R14) + m.Load("R13", &x.R13) + m.Load("R12", &x.R12) + m.Load("Rbp", &x.Rbp) + m.Load("Rbx", &x.Rbx) + m.Load("R11", &x.R11) + m.Load("R10", &x.R10) + m.Load("R9", &x.R9) + m.Load("R8", &x.R8) + m.Load("Rax", &x.Rax) + m.Load("Rcx", &x.Rcx) + m.Load("Rdx", &x.Rdx) + m.Load("Rsi", &x.Rsi) + m.Load("Rdi", &x.Rdi) + m.Load("Orig_rax", &x.Orig_rax) + m.Load("Rip", &x.Rip) + m.Load("Cs", &x.Cs) + m.Load("Eflags", &x.Eflags) + m.Load("Rsp", &x.Rsp) + m.Load("Ss", &x.Ss) + m.Load("Fs_base", &x.Fs_base) + m.Load("Gs_base", &x.Gs_base) + m.Load("Ds", &x.Ds) + m.Load("Es", &x.Es) + m.Load("Fs", &x.Fs) + m.Load("Gs", &x.Gs) +} + +func (x *State) beforeSave() {} +func (x *State) save(m state.Map) { + x.beforeSave() + var Regs syscallPtraceRegs = x.saveRegs() + m.SaveValue("Regs", Regs) + m.Save("x86FPState", &x.x86FPState) + m.Save("FeatureSet", &x.FeatureSet) +} + +func (x *State) load(m state.Map) { + m.LoadWait("x86FPState", &x.x86FPState) + m.Load("FeatureSet", &x.FeatureSet) + m.LoadValue("Regs", new(syscallPtraceRegs), func(y interface{}) { x.loadRegs(y.(syscallPtraceRegs)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *AuxEntry) beforeSave() {} +func (x *AuxEntry) save(m state.Map) { + x.beforeSave() + m.Save("Key", &x.Key) + m.Save("Value", &x.Value) +} + +func (x *AuxEntry) afterLoad() {} +func (x *AuxEntry) load(m state.Map) { + m.Load("Key", &x.Key) + m.Load("Value", &x.Value) +} + +func (x *SignalAct) beforeSave() {} +func (x *SignalAct) save(m state.Map) { + x.beforeSave() + m.Save("Handler", &x.Handler) + m.Save("Flags", &x.Flags) + m.Save("Restorer", &x.Restorer) + m.Save("Mask", &x.Mask) +} + +func (x *SignalAct) afterLoad() {} +func (x *SignalAct) load(m state.Map) { + m.Load("Handler", &x.Handler) + m.Load("Flags", &x.Flags) + m.Load("Restorer", &x.Restorer) + m.Load("Mask", &x.Mask) +} + +func (x *SignalStack) beforeSave() {} +func (x *SignalStack) save(m state.Map) { + x.beforeSave() + m.Save("Addr", &x.Addr) + m.Save("Flags", &x.Flags) + m.Save("Size", &x.Size) +} + +func (x *SignalStack) afterLoad() {} +func (x *SignalStack) load(m state.Map) { + m.Load("Addr", &x.Addr) + m.Load("Flags", &x.Flags) + m.Load("Size", &x.Size) +} + +func (x *SignalInfo) beforeSave() {} +func (x *SignalInfo) save(m state.Map) { + x.beforeSave() + m.Save("Signo", &x.Signo) + m.Save("Errno", &x.Errno) + m.Save("Code", &x.Code) + m.Save("Fields", &x.Fields) +} + +func (x *SignalInfo) afterLoad() {} +func (x *SignalInfo) load(m state.Map) { + m.Load("Signo", &x.Signo) + m.Load("Errno", &x.Errno) + m.Load("Code", &x.Code) + m.Load("Fields", &x.Fields) +} + +func init() { + state.Register("arch.MmapLayout", (*MmapLayout)(nil), state.Fns{Save: (*MmapLayout).save, Load: (*MmapLayout).load}) + state.Register("arch.context64", (*context64)(nil), state.Fns{Save: (*context64).save, Load: (*context64).load}) + state.Register("arch.syscallPtraceRegs", (*syscallPtraceRegs)(nil), state.Fns{Save: (*syscallPtraceRegs).save, Load: (*syscallPtraceRegs).load}) + state.Register("arch.State", (*State)(nil), state.Fns{Save: (*State).save, Load: (*State).load}) + state.Register("arch.AuxEntry", (*AuxEntry)(nil), state.Fns{Save: (*AuxEntry).save, Load: (*AuxEntry).load}) + state.Register("arch.SignalAct", (*SignalAct)(nil), state.Fns{Save: (*SignalAct).save, Load: (*SignalAct).load}) + state.Register("arch.SignalStack", (*SignalStack)(nil), state.Fns{Save: (*SignalStack).save, Load: (*SignalStack).load}) + state.Register("arch.SignalInfo", (*SignalInfo)(nil), state.Fns{Save: (*SignalInfo).save, Load: (*SignalInfo).load}) +} diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto deleted file mode 100644 index 9dc83e241..000000000 --- a/pkg/sentry/arch/registers.proto +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -message AMD64Registers { - uint64 rax = 1; - uint64 rbx = 2; - uint64 rcx = 3; - uint64 rdx = 4; - uint64 rsi = 5; - uint64 rdi = 6; - uint64 rsp = 7; - uint64 rbp = 8; - - uint64 r8 = 9; - uint64 r9 = 10; - uint64 r10 = 11; - uint64 r11 = 12; - uint64 r12 = 13; - uint64 r13 = 14; - uint64 r14 = 15; - uint64 r15 = 16; - - uint64 rip = 17; - uint64 rflags = 18; - uint64 orig_rax = 19; - uint64 cs = 20; - uint64 ds = 21; - uint64 es = 22; - uint64 fs = 23; - uint64 gs = 24; - uint64 ss = 25; - uint64 fs_base = 26; - uint64 gs_base = 27; -} - -message Registers { - oneof arch { - AMD64Registers amd64 = 1; - } -} diff --git a/pkg/sentry/arch/registers_go_proto/registers.pb.go b/pkg/sentry/arch/registers_go_proto/registers.pb.go new file mode 100755 index 000000000..088209be7 --- /dev/null +++ b/pkg/sentry/arch/registers_go_proto/registers.pb.go @@ -0,0 +1,367 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/arch/registers.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type AMD64Registers struct { + Rax uint64 `protobuf:"varint,1,opt,name=rax,proto3" json:"rax,omitempty"` + Rbx uint64 `protobuf:"varint,2,opt,name=rbx,proto3" json:"rbx,omitempty"` + Rcx uint64 `protobuf:"varint,3,opt,name=rcx,proto3" json:"rcx,omitempty"` + Rdx uint64 `protobuf:"varint,4,opt,name=rdx,proto3" json:"rdx,omitempty"` + Rsi uint64 `protobuf:"varint,5,opt,name=rsi,proto3" json:"rsi,omitempty"` + Rdi uint64 `protobuf:"varint,6,opt,name=rdi,proto3" json:"rdi,omitempty"` + Rsp uint64 `protobuf:"varint,7,opt,name=rsp,proto3" json:"rsp,omitempty"` + Rbp uint64 `protobuf:"varint,8,opt,name=rbp,proto3" json:"rbp,omitempty"` + R8 uint64 `protobuf:"varint,9,opt,name=r8,proto3" json:"r8,omitempty"` + R9 uint64 `protobuf:"varint,10,opt,name=r9,proto3" json:"r9,omitempty"` + R10 uint64 `protobuf:"varint,11,opt,name=r10,proto3" json:"r10,omitempty"` + R11 uint64 `protobuf:"varint,12,opt,name=r11,proto3" json:"r11,omitempty"` + R12 uint64 `protobuf:"varint,13,opt,name=r12,proto3" json:"r12,omitempty"` + R13 uint64 `protobuf:"varint,14,opt,name=r13,proto3" json:"r13,omitempty"` + R14 uint64 `protobuf:"varint,15,opt,name=r14,proto3" json:"r14,omitempty"` + R15 uint64 `protobuf:"varint,16,opt,name=r15,proto3" json:"r15,omitempty"` + Rip uint64 `protobuf:"varint,17,opt,name=rip,proto3" json:"rip,omitempty"` + Rflags uint64 `protobuf:"varint,18,opt,name=rflags,proto3" json:"rflags,omitempty"` + OrigRax uint64 `protobuf:"varint,19,opt,name=orig_rax,json=origRax,proto3" json:"orig_rax,omitempty"` + Cs uint64 `protobuf:"varint,20,opt,name=cs,proto3" json:"cs,omitempty"` + Ds uint64 `protobuf:"varint,21,opt,name=ds,proto3" json:"ds,omitempty"` + Es uint64 `protobuf:"varint,22,opt,name=es,proto3" json:"es,omitempty"` + Fs uint64 `protobuf:"varint,23,opt,name=fs,proto3" json:"fs,omitempty"` + Gs uint64 `protobuf:"varint,24,opt,name=gs,proto3" json:"gs,omitempty"` + Ss uint64 `protobuf:"varint,25,opt,name=ss,proto3" json:"ss,omitempty"` + FsBase uint64 `protobuf:"varint,26,opt,name=fs_base,json=fsBase,proto3" json:"fs_base,omitempty"` + GsBase uint64 `protobuf:"varint,27,opt,name=gs_base,json=gsBase,proto3" json:"gs_base,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AMD64Registers) Reset() { *m = AMD64Registers{} } +func (m *AMD64Registers) String() string { return proto.CompactTextString(m) } +func (*AMD64Registers) ProtoMessage() {} +func (*AMD64Registers) Descriptor() ([]byte, []int) { + return fileDescriptor_082b7510610e0457, []int{0} +} + +func (m *AMD64Registers) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AMD64Registers.Unmarshal(m, b) +} +func (m *AMD64Registers) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AMD64Registers.Marshal(b, m, deterministic) +} +func (m *AMD64Registers) XXX_Merge(src proto.Message) { + xxx_messageInfo_AMD64Registers.Merge(m, src) +} +func (m *AMD64Registers) XXX_Size() int { + return xxx_messageInfo_AMD64Registers.Size(m) +} +func (m *AMD64Registers) XXX_DiscardUnknown() { + xxx_messageInfo_AMD64Registers.DiscardUnknown(m) +} + +var xxx_messageInfo_AMD64Registers proto.InternalMessageInfo + +func (m *AMD64Registers) GetRax() uint64 { + if m != nil { + return m.Rax + } + return 0 +} + +func (m *AMD64Registers) GetRbx() uint64 { + if m != nil { + return m.Rbx + } + return 0 +} + +func (m *AMD64Registers) GetRcx() uint64 { + if m != nil { + return m.Rcx + } + return 0 +} + +func (m *AMD64Registers) GetRdx() uint64 { + if m != nil { + return m.Rdx + } + return 0 +} + +func (m *AMD64Registers) GetRsi() uint64 { + if m != nil { + return m.Rsi + } + return 0 +} + +func (m *AMD64Registers) GetRdi() uint64 { + if m != nil { + return m.Rdi + } + return 0 +} + +func (m *AMD64Registers) GetRsp() uint64 { + if m != nil { + return m.Rsp + } + return 0 +} + +func (m *AMD64Registers) GetRbp() uint64 { + if m != nil { + return m.Rbp + } + return 0 +} + +func (m *AMD64Registers) GetR8() uint64 { + if m != nil { + return m.R8 + } + return 0 +} + +func (m *AMD64Registers) GetR9() uint64 { + if m != nil { + return m.R9 + } + return 0 +} + +func (m *AMD64Registers) GetR10() uint64 { + if m != nil { + return m.R10 + } + return 0 +} + +func (m *AMD64Registers) GetR11() uint64 { + if m != nil { + return m.R11 + } + return 0 +} + +func (m *AMD64Registers) GetR12() uint64 { + if m != nil { + return m.R12 + } + return 0 +} + +func (m *AMD64Registers) GetR13() uint64 { + if m != nil { + return m.R13 + } + return 0 +} + +func (m *AMD64Registers) GetR14() uint64 { + if m != nil { + return m.R14 + } + return 0 +} + +func (m *AMD64Registers) GetR15() uint64 { + if m != nil { + return m.R15 + } + return 0 +} + +func (m *AMD64Registers) GetRip() uint64 { + if m != nil { + return m.Rip + } + return 0 +} + +func (m *AMD64Registers) GetRflags() uint64 { + if m != nil { + return m.Rflags + } + return 0 +} + +func (m *AMD64Registers) GetOrigRax() uint64 { + if m != nil { + return m.OrigRax + } + return 0 +} + +func (m *AMD64Registers) GetCs() uint64 { + if m != nil { + return m.Cs + } + return 0 +} + +func (m *AMD64Registers) GetDs() uint64 { + if m != nil { + return m.Ds + } + return 0 +} + +func (m *AMD64Registers) GetEs() uint64 { + if m != nil { + return m.Es + } + return 0 +} + +func (m *AMD64Registers) GetFs() uint64 { + if m != nil { + return m.Fs + } + return 0 +} + +func (m *AMD64Registers) GetGs() uint64 { + if m != nil { + return m.Gs + } + return 0 +} + +func (m *AMD64Registers) GetSs() uint64 { + if m != nil { + return m.Ss + } + return 0 +} + +func (m *AMD64Registers) GetFsBase() uint64 { + if m != nil { + return m.FsBase + } + return 0 +} + +func (m *AMD64Registers) GetGsBase() uint64 { + if m != nil { + return m.GsBase + } + return 0 +} + +type Registers struct { + // Types that are valid to be assigned to Arch: + // *Registers_Amd64 + Arch isRegisters_Arch `protobuf_oneof:"arch"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Registers) Reset() { *m = Registers{} } +func (m *Registers) String() string { return proto.CompactTextString(m) } +func (*Registers) ProtoMessage() {} +func (*Registers) Descriptor() ([]byte, []int) { + return fileDescriptor_082b7510610e0457, []int{1} +} + +func (m *Registers) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Registers.Unmarshal(m, b) +} +func (m *Registers) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Registers.Marshal(b, m, deterministic) +} +func (m *Registers) XXX_Merge(src proto.Message) { + xxx_messageInfo_Registers.Merge(m, src) +} +func (m *Registers) XXX_Size() int { + return xxx_messageInfo_Registers.Size(m) +} +func (m *Registers) XXX_DiscardUnknown() { + xxx_messageInfo_Registers.DiscardUnknown(m) +} + +var xxx_messageInfo_Registers proto.InternalMessageInfo + +type isRegisters_Arch interface { + isRegisters_Arch() +} + +type Registers_Amd64 struct { + Amd64 *AMD64Registers `protobuf:"bytes,1,opt,name=amd64,proto3,oneof"` +} + +func (*Registers_Amd64) isRegisters_Arch() {} + +func (m *Registers) GetArch() isRegisters_Arch { + if m != nil { + return m.Arch + } + return nil +} + +func (m *Registers) GetAmd64() *AMD64Registers { + if x, ok := m.GetArch().(*Registers_Amd64); ok { + return x.Amd64 + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*Registers) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*Registers_Amd64)(nil), + } +} + +func init() { + proto.RegisterType((*AMD64Registers)(nil), "gvisor.AMD64Registers") + proto.RegisterType((*Registers)(nil), "gvisor.Registers") +} + +func init() { proto.RegisterFile("pkg/sentry/arch/registers.proto", fileDescriptor_082b7510610e0457) } + +var fileDescriptor_082b7510610e0457 = []byte{ + // 354 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x54, 0x92, 0x4d, 0x6f, 0xe2, 0x30, + 0x10, 0x86, 0x17, 0x08, 0x01, 0xcc, 0x2e, 0xcb, 0x66, 0x5b, 0x18, 0xda, 0x43, 0x2b, 0x4e, 0x3d, + 0x05, 0x02, 0x01, 0xc1, 0xb1, 0xb4, 0x87, 0x5e, 0x7a, 0xc9, 0x1f, 0x40, 0xf9, 0x74, 0xad, 0x7e, + 0xc4, 0xf2, 0xa0, 0x2a, 0x3d, 0xf7, 0x8f, 0x57, 0xf6, 0xd8, 0xaa, 0x7a, 0xcb, 0xf3, 0xcc, 0x6b, + 0x39, 0x93, 0xbc, 0xec, 0x4a, 0x3e, 0xf3, 0x05, 0x96, 0x6f, 0x27, 0xf5, 0xb1, 0x48, 0x55, 0xfe, + 0xb4, 0x50, 0x25, 0x17, 0x78, 0x2a, 0x15, 0x86, 0x52, 0xd5, 0xa7, 0x3a, 0xf0, 0xf9, 0xbb, 0xc0, + 0x5a, 0xcd, 0x3f, 0x3d, 0x36, 0xba, 0x7d, 0xbc, 0xdf, 0xc6, 0x89, 0x0b, 0x04, 0x63, 0xd6, 0x51, + 0x69, 0x03, 0xad, 0xeb, 0xd6, 0x8d, 0x97, 0xe8, 0x47, 0x63, 0xb2, 0x06, 0xda, 0xd6, 0x64, 0x64, + 0xf2, 0x06, 0x3a, 0xd6, 0xe4, 0x64, 0x8a, 0x06, 0x3c, 0x6b, 0x0a, 0x32, 0x28, 0xa0, 0x6b, 0x0d, + 0x0a, 0xca, 0x08, 0xf0, 0x5d, 0x86, 0x0c, 0x4a, 0xe8, 0xb9, 0x8c, 0xa4, 0xbb, 0x24, 0xf4, 0xdd, + 0x5d, 0x32, 0x18, 0xb1, 0xb6, 0xda, 0xc1, 0xc0, 0x88, 0xb6, 0xda, 0x19, 0xde, 0x03, 0xb3, 0xbc, + 0x37, 0x27, 0xa2, 0x25, 0x0c, 0xed, 0x89, 0x68, 0x49, 0x26, 0x82, 0xdf, 0xce, 0x44, 0x64, 0x56, + 0xf0, 0xc7, 0x99, 0x15, 0x99, 0x35, 0x8c, 0x9c, 0x59, 0x93, 0x89, 0xe1, 0xaf, 0x33, 0x31, 0x99, + 0x0d, 0x8c, 0x9d, 0xd9, 0x18, 0x23, 0x24, 0xfc, 0xb3, 0x46, 0xc8, 0x60, 0xc2, 0x7c, 0x55, 0xbd, + 0xa4, 0x1c, 0x21, 0x30, 0xd2, 0x52, 0x30, 0x63, 0xfd, 0x5a, 0x09, 0x7e, 0xd4, 0x9f, 0xf2, 0xbf, + 0x99, 0xf4, 0x34, 0x27, 0x69, 0xa3, 0x17, 0xc8, 0x11, 0xce, 0x68, 0x81, 0x1c, 0x35, 0x17, 0x08, + 0xe7, 0xc4, 0x85, 0xe1, 0x12, 0x61, 0x42, 0x5c, 0x1a, 0xae, 0x10, 0xa6, 0xc4, 0x95, 0x61, 0x8e, + 0x00, 0xc4, 0xdc, 0x30, 0x22, 0xcc, 0x88, 0x11, 0x83, 0x29, 0xeb, 0x55, 0x78, 0xcc, 0x52, 0x2c, + 0xe1, 0x82, 0xde, 0xa9, 0xc2, 0x43, 0x8a, 0xa5, 0x1e, 0x70, 0x3b, 0xb8, 0xa4, 0x01, 0x37, 0x83, + 0xf9, 0x1d, 0x1b, 0x7c, 0xff, 0xff, 0x90, 0x75, 0xd3, 0xd7, 0x62, 0x1b, 0x9b, 0x06, 0x0c, 0x57, + 0x93, 0x90, 0xaa, 0x12, 0xfe, 0xac, 0xc9, 0xc3, 0xaf, 0x84, 0x62, 0x07, 0x9f, 0x79, 0xba, 0x62, + 0x99, 0x6f, 0x9a, 0xb5, 0xfe, 0x0a, 0x00, 0x00, 0xff, 0xff, 0xb4, 0xcc, 0x03, 0x27, 0x7c, 0x02, + 0x00, 0x00, +} diff --git a/pkg/sentry/context/BUILD b/pkg/sentry/context/BUILD deleted file mode 100644 index 8dc1a77b1..000000000 --- a/pkg/sentry/context/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/amutex", - "//pkg/log", - ], -) diff --git a/pkg/sentry/context/context_state_autogen.go b/pkg/sentry/context/context_state_autogen.go new file mode 100755 index 000000000..7dd55bfea --- /dev/null +++ b/pkg/sentry/context/context_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package context + diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD deleted file mode 100644 index 3b6841b7e..000000000 --- a/pkg/sentry/context/contexttest/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "contexttest", - testonly = 1, - srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context/contexttest", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/memutil", - "//pkg/sentry/context", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/platform/ptrace", - "//pkg/sentry/uniqueid", - ], -) diff --git a/pkg/sentry/context/contexttest/contexttest.go b/pkg/sentry/context/contexttest/contexttest.go deleted file mode 100644 index 6d0120518..000000000 --- a/pkg/sentry/context/contexttest/contexttest.go +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package contexttest builds a test context.Context. -package contexttest - -import ( - "os" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ptrace" - "gvisor.dev/gvisor/pkg/sentry/uniqueid" -) - -// Context returns a Context that may be used in tests. Uses ptrace as the -// platform.Platform. -// -// Note that some filesystems may require a minimal kernel for testing, which -// this test context does not provide. For such tests, see kernel/contexttest. -func Context(tb testing.TB) context.Context { - const memfileName = "contexttest-memory" - memfd, err := memutil.CreateMemFD(memfileName, 0) - if err != nil { - tb.Fatalf("error creating application memory file: %v", err) - } - memfile := os.NewFile(uintptr(memfd), memfileName) - mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) - if err != nil { - memfile.Close() - tb.Fatalf("error creating pgalloc.MemoryFile: %v", err) - } - p, err := ptrace.New() - if err != nil { - tb.Fatal(err) - } - // Test usage of context.Background is fine. - return &TestContext{ - Context: context.Background(), - l: limits.NewLimitSet(), - mf: mf, - platform: p, - otherValues: make(map[interface{}]interface{}), - } -} - -// TestContext represents a context with minimal functionality suitable for -// running tests. -type TestContext struct { - context.Context - l *limits.LimitSet - mf *pgalloc.MemoryFile - platform platform.Platform - otherValues map[interface{}]interface{} -} - -// globalUniqueID tracks incremental unique identifiers for tests. -var globalUniqueID uint64 - -// globalUniqueIDProvider implements unix.UniqueIDProvider. -type globalUniqueIDProvider struct{} - -// UniqueID implements unix.UniqueIDProvider.UniqueID. -func (*globalUniqueIDProvider) UniqueID() uint64 { - return atomic.AddUint64(&globalUniqueID, 1) -} - -// lastInotifyCookie is a monotonically increasing counter for generating unique -// inotify cookies. Must be accessed using atomic ops. -var lastInotifyCookie uint32 - -// hostClock implements ktime.Clock. -type hostClock struct { - ktime.WallRateClock - ktime.NoClockEvents -} - -// Now implements ktime.Clock.Now. -func (hostClock) Now() ktime.Time { - return ktime.FromNanoseconds(time.Now().UnixNano()) -} - -// RegisterValue registers additional values with this test context. Useful for -// providing values from external packages that contexttest can't depend on. -func (t *TestContext) RegisterValue(key, value interface{}) { - t.otherValues[key] = value -} - -// Value implements context.Context. -func (t *TestContext) Value(key interface{}) interface{} { - switch key { - case limits.CtxLimits: - return t.l - case pgalloc.CtxMemoryFile: - return t.mf - case pgalloc.CtxMemoryFileProvider: - return t - case platform.CtxPlatform: - return t.platform - case uniqueid.CtxGlobalUniqueID: - return (*globalUniqueIDProvider).UniqueID(nil) - case uniqueid.CtxGlobalUniqueIDProvider: - return &globalUniqueIDProvider{} - case uniqueid.CtxInotifyCookie: - return atomic.AddUint32(&lastInotifyCookie, 1) - case ktime.CtxRealtimeClock: - return hostClock{} - default: - if val, ok := t.otherValues[key]; ok { - return val - } - return t.Context.Value(key) - } -} - -// MemoryFile implements pgalloc.MemoryFileProvider.MemoryFile. -func (t *TestContext) MemoryFile() *pgalloc.MemoryFile { - return t.mf -} - -// RootContext returns a Context that may be used in tests that need root -// credentials. Uses ptrace as the platform.Platform. -func RootContext(tb testing.TB) context.Context { - return WithCreds(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) -} - -// WithCreds returns a copy of ctx carrying creds. -func WithCreds(ctx context.Context, creds *auth.Credentials) context.Context { - return &authContext{ctx, creds} -} - -type authContext struct { - context.Context - creds *auth.Credentials -} - -// Value implements context.Context. -func (ac *authContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return ac.creds - default: - return ac.Context.Value(key) - } -} - -// WithLimitSet returns a copy of ctx carrying l. -func WithLimitSet(ctx context.Context, l *limits.LimitSet) context.Context { - return limitContext{ctx, l} -} - -type limitContext struct { - context.Context - l *limits.LimitSet -} - -// Value implements context.Context. -func (lc limitContext) Value(key interface{}) interface{} { - switch key { - case limits.CtxLimits: - return lc.l - default: - return lc.Context.Value(key) - } -} diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD deleted file mode 100644 index 15a1fe8a9..000000000 --- a/pkg/sentry/control/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "control", - srcs = [ - "control.go", - "pprof.go", - "proc.go", - "state.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/control", - visibility = [ - "//pkg/sentry:internal", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/fd", - "//pkg/log", - "//pkg/sentry/fs", - "//pkg/sentry/fs/host", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/state", - "//pkg/sentry/usage", - "//pkg/sentry/watchdog", - "//pkg/urpc", - ], -) - -go_test( - name = "control_test", - size = "small", - srcs = ["proc_test.go"], - embed = [":control"], - deps = [ - "//pkg/log", - "//pkg/sentry/kernel/time", - "//pkg/sentry/usage", - ], -) diff --git a/pkg/sentry/control/control_state_autogen.go b/pkg/sentry/control/control_state_autogen.go new file mode 100755 index 000000000..a1de4bc6d --- /dev/null +++ b/pkg/sentry/control/control_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package control + diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go deleted file mode 100644 index d8ada2694..000000000 --- a/pkg/sentry/control/proc_test.go +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package control - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/log" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/usage" -) - -func init() { - log.SetLevel(log.Debug) -} - -// Tests that ProcessData.Table() prints with the correct format. -func TestProcessListTable(t *testing.T) { - testCases := []struct { - pl []*Process - expected string - }{ - { - pl: []*Process{}, - expected: "UID PID PPID C STIME TIME CMD", - }, - { - pl: []*Process{ - { - UID: 0, - PID: 0, - PPID: 0, - C: 0, - STime: "0", - Time: "0", - Cmd: "zero", - }, - { - UID: 1, - PID: 1, - PPID: 1, - C: 1, - STime: "1", - Time: "1", - Cmd: "one", - }, - }, - expected: `UID PID PPID C STIME TIME CMD -0 0 0 0 0 0 zero -1 1 1 1 1 1 one`, - }, - } - - for _, tc := range testCases { - output := ProcessListToTable(tc.pl) - - if tc.expected != output { - t.Errorf("PrintTable(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected) - } - } -} - -func TestProcessListJSON(t *testing.T) { - testCases := []struct { - pl []*Process - expected string - }{ - { - pl: []*Process{}, - expected: "[]", - }, - { - pl: []*Process{ - { - UID: 0, - PID: 0, - PPID: 0, - C: 0, - STime: "0", - Time: "0", - Cmd: "zero", - }, - { - UID: 1, - PID: 1, - PPID: 1, - C: 1, - STime: "1", - Time: "1", - Cmd: "one", - }, - }, - expected: "[0,1]", - }, - } - - for _, tc := range testCases { - output, err := PrintPIDsJSON(tc.pl) - if err != nil { - t.Errorf("failed to generate JSON: %v", err) - } - - if tc.expected != output { - t.Errorf("PrintJSON(%v): got:\n%s\nwant:\n%s", tc.pl, output, tc.expected) - } - } -} - -func TestPercentCPU(t *testing.T) { - testCases := []struct { - stats usage.CPUStats - startTime ktime.Time - now ktime.Time - expected int32 - }{ - { - // Verify that 100% use is capped at 99. - stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9}, - startTime: ktime.FromNanoseconds(7e9), - now: ktime.FromNanoseconds(9e9), - expected: 99, - }, - { - // Verify that if usage > lifetime, we get at most 99% - // usage. - stats: usage.CPUStats{UserTime: 2e9, SysTime: 2e9}, - startTime: ktime.FromNanoseconds(7e9), - now: ktime.FromNanoseconds(9e9), - expected: 99, - }, - { - // Verify that 50% usage is reported correctly. - stats: usage.CPUStats{UserTime: 1e9, SysTime: 1e9}, - startTime: ktime.FromNanoseconds(12e9), - now: ktime.FromNanoseconds(16e9), - expected: 50, - }, - { - // Verify that 0% usage is reported correctly. - stats: usage.CPUStats{UserTime: 0, SysTime: 0}, - startTime: ktime.FromNanoseconds(12e9), - now: ktime.FromNanoseconds(14e9), - expected: 0, - }, - } - - for _, tc := range testCases { - if pcpu := percentCPU(tc.stats, tc.startTime, tc.now); pcpu != tc.expected { - t.Errorf("percentCPU(%v, %v, %v): got %d, want %d", tc.stats, tc.startTime, tc.now, pcpu, tc.expected) - } - } -} diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD deleted file mode 100644 index 7e8918722..000000000 --- a/pkg/sentry/device/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "device", - srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/device", - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/abi/linux"], -) - -go_test( - name = "device_test", - size = "small", - srcs = ["device_test.go"], - embed = [":device"], -) diff --git a/pkg/sentry/device/device_state_autogen.go b/pkg/sentry/device/device_state_autogen.go new file mode 100755 index 000000000..be40f6c77 --- /dev/null +++ b/pkg/sentry/device/device_state_autogen.go @@ -0,0 +1,52 @@ +// automatically generated by stateify. + +package device + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Registry) beforeSave() {} +func (x *Registry) save(m state.Map) { + x.beforeSave() + m.Save("lastAnonDeviceMinor", &x.lastAnonDeviceMinor) + m.Save("devices", &x.devices) +} + +func (x *Registry) afterLoad() {} +func (x *Registry) load(m state.Map) { + m.Load("lastAnonDeviceMinor", &x.lastAnonDeviceMinor) + m.Load("devices", &x.devices) +} + +func (x *ID) beforeSave() {} +func (x *ID) save(m state.Map) { + x.beforeSave() + m.Save("Major", &x.Major) + m.Save("Minor", &x.Minor) +} + +func (x *ID) afterLoad() {} +func (x *ID) load(m state.Map) { + m.Load("Major", &x.Major) + m.Load("Minor", &x.Minor) +} + +func (x *Device) beforeSave() {} +func (x *Device) save(m state.Map) { + x.beforeSave() + m.Save("ID", &x.ID) + m.Save("last", &x.last) +} + +func (x *Device) afterLoad() {} +func (x *Device) load(m state.Map) { + m.Load("ID", &x.ID) + m.Load("last", &x.last) +} + +func init() { + state.Register("device.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load}) + state.Register("device.ID", (*ID)(nil), state.Fns{Save: (*ID).save, Load: (*ID).load}) + state.Register("device.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load}) +} diff --git a/pkg/sentry/device/device_test.go b/pkg/sentry/device/device_test.go deleted file mode 100644 index e3f51ce4f..000000000 --- a/pkg/sentry/device/device_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package device - -import ( - "testing" -) - -func TestMultiDevice(t *testing.T) { - device := &MultiDevice{} - - // Check that Load fails to install virtual inodes that are - // uninitialized. - if device.Load(MultiDeviceKey{}, 0) { - t.Fatalf("got load of invalid virtual inode 0, want unsuccessful") - } - - inode := device.Map(MultiDeviceKey{}) - - // Assert that the same raw device and inode map to - // a consistent virtual inode. - if i := device.Map(MultiDeviceKey{}); i != inode { - t.Fatalf("got inode %d, want %d in %s", i, inode, device) - } - - // Assert that a new inode or new device does not conflict. - if i := device.Map(MultiDeviceKey{Device: 0, Inode: 1}); i == inode { - t.Fatalf("got reused inode %d, want new distinct inode in %s", i, device) - } - last := device.Map(MultiDeviceKey{Device: 1, Inode: 0}) - if last == inode { - t.Fatalf("got reused inode %d, want new distinct inode in %s", last, device) - } - - // Virtual is the virtual inode we want to load. - virtual := last + 1 - - // Assert that we can load a virtual inode at a new place. - if !device.Load(MultiDeviceKey{Device: 0, Inode: 2}, virtual) { - t.Fatalf("got load of virtual inode %d failed, want success in %s", virtual, device) - } - - // Assert that the next inode skips over the loaded one. - if i := device.Map(MultiDeviceKey{Device: 0, Inode: 3}); i != virtual+1 { - t.Fatalf("got inode %d, want %d in %s", i, virtual+1, device) - } -} diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD deleted file mode 100644 index e5b9b8bc2..000000000 --- a/pkg/sentry/fs/BUILD +++ /dev/null @@ -1,134 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "fs", - srcs = [ - "attr.go", - "context.go", - "copy_up.go", - "dentry.go", - "dirent.go", - "dirent_cache.go", - "dirent_cache_limiter.go", - "dirent_list.go", - "dirent_state.go", - "event_list.go", - "file.go", - "file_operations.go", - "file_overlay.go", - "file_state.go", - "filesystems.go", - "flags.go", - "fs.go", - "inode.go", - "inode_inotify.go", - "inode_operations.go", - "inode_overlay.go", - "inotify.go", - "inotify_event.go", - "inotify_watch.go", - "mock.go", - "mount.go", - "mount_overlay.go", - "mounts.go", - "offset.go", - "overlay.go", - "path.go", - "restore.go", - "save.go", - "seek.go", - "splice.go", - "sync.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/log", - "//pkg/metric", - "//pkg/p9", - "//pkg/refs", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs/lock", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/platform", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/state", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_template_instance( - name = "dirent_list", - out = "dirent_list.go", - package = "fs", - prefix = "dirent", - template = "//pkg/ilist:generic_list", - types = { - "Linker": "*Dirent", - "Element": "*Dirent", - }, -) - -go_template_instance( - name = "event_list", - out = "event_list.go", - package = "fs", - prefix = "event", - template = "//pkg/ilist:generic_list", - types = { - "Linker": "*Event", - "Element": "*Event", - }, -) - -go_test( - name = "fs_x_test", - size = "small", - srcs = [ - "copy_up_test.go", - "file_overlay_test.go", - "inode_overlay_test.go", - "mounts_test.go", - ], - deps = [ - ":fs", - "//pkg/sentry/context", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel/contexttest", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) - -go_test( - name = "fs_test", - size = "small", - srcs = [ - "dirent_cache_test.go", - "dirent_refs_test.go", - "mount_test.go", - "path_test.go", - ], - embed = [":fs"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - ], -) diff --git a/pkg/sentry/fs/README.md b/pkg/sentry/fs/README.md deleted file mode 100644 index f53ed3eaa..000000000 --- a/pkg/sentry/fs/README.md +++ /dev/null @@ -1,229 +0,0 @@ -This package provides an implementation of the Linux virtual filesystem. - -[TOC] - -## Overview - -- An `fs.Dirent` caches an `fs.Inode` in memory at a path in the VFS, giving - the `fs.Inode` a relative position with respect to other `fs.Inode`s. - -- If an `fs.Dirent` is referenced by two file descriptors, then those file - descriptors are coherent with each other: they depend on the same - `fs.Inode`. - -- A mount point is an `fs.Dirent` for which `fs.Dirent.mounted` is true. It - exposes the root of a mounted filesystem. - -- The `fs.Inode` produced by a registered filesystem on mount(2) owns an - `fs.MountedFilesystem` from which other `fs.Inode`s will be looked up. For a - remote filesystem, the `fs.MountedFilesystem` owns the connection to that - remote filesystem. - -- In general: - -``` -fs.Inode <------------------------------ -| | -| | -produced by | -exactly one | -| responsible for the -| virtual identity of -v | -fs.MountedFilesystem ------------------- -``` - -Glossary: - -- VFS: virtual filesystem. - -- inode: a virtual file object holding a cached view of a file on a backing - filesystem (includes metadata and page caches). - -- superblock: the virtual state of a mounted filesystem (e.g. the virtual - inode number set). - -- mount namespace: a view of the mounts under a root (during path traversal, - the VFS makes visible/follows the mount point that is in the current task's - mount namespace). - -## Save and restore - -An application's hard dependencies on filesystem state can be broken down into -two categories: - -- The state necessary to execute a traversal on or view the *virtual* - filesystem hierarchy, regardless of what files an application has open. - -- The state necessary to represent open files. - -The first is always necessary to save and restore. An application may never have -any open file descriptors, but across save and restore it should see a coherent -view of any mount namespace. NOTE(b/63601033): Currently only one "initial" -mount namespace is supported. - -The second is so that system calls across save and restore are coherent with -each other (e.g. so that unintended re-reads or overwrites do not occur). - -Specifically this state is: - -- An `fs.MountManager` containing mount points. - -- A `kernel.FDMap` containing pointers to open files. - -Anything else managed by the VFS that can be easily loaded into memory from a -filesystem is synced back to those filesystems and is not saved. Examples are -pages in page caches used for optimizations (i.e. readahead and writeback), and -directory entries used to accelerate path lookups. - -### Mount points - -Saving and restoring a mount point means saving and restoring: - -- The root of the mounted filesystem. - -- Mount flags, which control how the VFS interacts with the mounted - filesystem. - -- Any relevant metadata about the mounted filesystem. - -- All `fs.Inode`s referenced by the application that reside under the mount - point. - -`fs.MountedFilesystem` is metadata about a filesystem that is mounted. It is -referenced by every `fs.Inode` loaded into memory under the mount point -including the `fs.Inode` of the mount point itself. The `fs.MountedFilesystem` -maps file objects on the filesystem to a virtualized `fs.Inode` number and vice -versa. - -To restore all `fs.Inode`s under a given mount point, each `fs.Inode` leverages -its dependency on an `fs.MountedFilesystem`. Since the `fs.MountedFilesystem` -knows how an `fs.Inode` maps to a file object on a backing filesystem, this -mapping can be trivially consulted by each `fs.Inode` when the `fs.Inode` is -restored. - -In detail, a mount point is saved in two steps: - -- First, after the kernel is paused but before state.Save, we walk all mount - namespaces and install a mapping from `fs.Inode` numbers to file paths - relative to the root of the mounted filesystem in each - `fs.MountedFilesystem`. This is subsequently called the set of `fs.Inode` - mappings. - -- Second, during state.Save, each `fs.MountedFilesystem` decides whether to - save the set of `fs.Inode` mappings. In-memory filesystems, like tmpfs, have - no need to save a set of `fs.Inode` mappings, since the `fs.Inode`s can be - entirely encoded in state file. Each `fs.MountedFilesystem` also optionally - saves the device name from when the filesystem was originally mounted. Each - `fs.Inode` saves its virtual identifier and a reference to a - `fs.MountedFilesystem`. - -A mount point is restored in two steps: - -- First, before state.Load, all mount configurations are stored in a global - `fs.RestoreEnvironment`. This tells us what mount points the user wants to - restore and how to re-establish pointers to backing filesystems. - -- Second, during state.Load, each `fs.MountedFilesystem` optionally searches - for a mount in the `fs.RestoreEnvironment` that matches its saved device - name. The `fs.MountedFilesystem` then restablishes a pointer to the root of - the mounted filesystem. For example, the mount specification provides the - network connection for a mounted remote filesystem client to communicate - with its remote file server. The `fs.MountedFilesystem` also trivially loads - its set of `fs.Inode` mappings. When an `fs.Inode` is encountered, the - `fs.Inode` loads its virtual identifier and its reference a - `fs.MountedFilesystem`. It uses the `fs.MountedFilesystem` to obtain the - root of the mounted filesystem and the `fs.Inode` mappings to obtain the - relative file path to its data. With these, the `fs.Inode` re-establishes a - pointer to its file object. - -A mount point can trivially restore its `fs.Inode`s in parallel since -`fs.Inode`s have a restore dependency on their `fs.MountedFilesystem` and not on -each other. - -### Open files - -An `fs.File` references the following filesystem objects: - -```go -fs.File -> fs.Dirent -> fs.Inode -> fs.MountedFilesystem -``` - -The `fs.Inode` is restored using its `fs.MountedFilesystem`. The -[Mount points](#mount-points) section above describes how this happens in -detail. The `fs.Dirent` restores its pointer to an `fs.Inode`, pointers to -parent and children `fs.Dirents`, and the basename of the file. - -Otherwise an `fs.File` restores flags, an offset, and a unique identifier (only -used internally). - -It may use the `fs.Inode`, which it indirectly holds a reference on through the -`fs.Dirent`, to restablish an open file handle on the backing filesystem (e.g. -to continue reading and writing). - -## Overlay - -The overlay implementation in the fs package takes Linux overlayfs as a frame of -reference but corrects for several POSIX consistency errors. - -In Linux overlayfs, the `struct inode` used for reading and writing to the same -file may be different. This is because the `struct inode` is dissociated with -the process of copying up the file from the upper to the lower directory. Since -flock(2) and fcntl(2) locks, inotify(7) watches, page caches, and a file's -identity are all stored directly or indirectly off the `struct inode`, these -properties of the `struct inode` may be stale after the first modification. This -can lead to file locking bugs, missed inotify events, and inconsistent data in -shared memory mappings of files, to name a few problems. - -The fs package maintains a single `fs.Inode` to represent a directory entry in -an overlay and defines operations on this `fs.Inode` which synchronize with the -copy up process. This achieves several things: - -+ File locks, inotify watches, and the identity of the file need not be copied - at all. - -+ Memory mappings of files coordinate with the copy up process so that if a - file in the lower directory is memory mapped, all references to it are - invalidated, forcing the application to re-fault on memory mappings of the - file under the upper directory. - -The `fs.Inode` holds metadata about files in the upper and/or lower directories -via an `fs.overlayEntry`. The `fs.overlayEntry` implements the `fs.Mappable` -interface. It multiplexes between upper and lower directory memory mappings and -stores a copy of memory references so they can be transferred to the upper -directory `fs.Mappable` when the file is copied up. - -The lower filesystem in an overlay may contain another (nested) overlay, but the -upper filesystem may not contain another overlay. In other words, nested -overlays form a tree structure that only allows branching in the lower -filesystem. - -Caching decisions in the overlay are delegated to the upper filesystem, meaning -that the Keep and Revalidate methods on the overlay return the same values as -the upper filesystem. A small wrinkle is that the lower filesystem is not -allowed to return `true` from Revalidate, as the overlay can not reload inodes -from the lower filesystem. A lower filesystem that does return `true` from -Revalidate will trigger a panic. - -The `fs.Inode` also holds a reference to a `fs.MountedFilesystem` that -normalizes across the mounted filesystem state of the upper and lower -directories. - -When a file is copied from the lower to the upper directory, attempts to -interact with the file block until the copy completes. All copying synchronizes -with rename(2). - -## Future Work - -### Overlay - -When a file is copied from a lower directory to an upper directory, several -locks are taken: the global renamuMu and the copyMu of the `fs.Inode` being -copied. This blocks operations on the file, including fault handling of memory -mappings. Performance could be improved by copying files into a temporary -directory that resides on the same filesystem as the upper directory and doing -an atomic rename, holding locks only during the rename operation. - -Additionally files are copied up synchronously. For large files, this causes a -noticeable latency. Performance could be improved by pipelining copies at -non-overlapping file offsets. diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD deleted file mode 100644 index ae1c9cf76..000000000 --- a/pkg/sentry/fs/anon/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "anon", - srcs = [ - "anon.go", - "device.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/anon", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/fs/anon/anon_state_autogen.go b/pkg/sentry/fs/anon/anon_state_autogen.go new file mode 100755 index 000000000..fcb914212 --- /dev/null +++ b/pkg/sentry/fs/anon/anon_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package anon + diff --git a/pkg/sentry/fs/ashmem/BUILD b/pkg/sentry/fs/ashmem/BUILD deleted file mode 100644 index 5b755cec5..000000000 --- a/pkg/sentry/fs/ashmem/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -go_library( - name = "ashmem", - srcs = [ - "area.go", - "device.go", - "pin_board.go", - "uint64_range.go", - "uint64_set.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ashmem", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "ashmem_test", - size = "small", - srcs = ["pin_board_test.go"], - embed = [":ashmem"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/usermem", - ], -) - -go_template_instance( - name = "uint64_range", - out = "uint64_range.go", - package = "ashmem", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "uint64_set", - out = "uint64_set.go", - package = "ashmem", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "Range", - "Value": "noValue", - "Functions": "setFunctions", - }, -) diff --git a/pkg/sentry/fs/ashmem/ashmem_state_autogen.go b/pkg/sentry/fs/ashmem/ashmem_state_autogen.go new file mode 100755 index 000000000..13defb033 --- /dev/null +++ b/pkg/sentry/fs/ashmem/ashmem_state_autogen.go @@ -0,0 +1,123 @@ +// automatically generated by stateify. + +package ashmem + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Area) beforeSave() {} +func (x *Area) save(m state.Map) { + x.beforeSave() + m.Save("ad", &x.ad) + m.Save("tmpfsFile", &x.tmpfsFile) + m.Save("name", &x.name) + m.Save("size", &x.size) + m.Save("perms", &x.perms) + m.Save("pb", &x.pb) +} + +func (x *Area) afterLoad() {} +func (x *Area) load(m state.Map) { + m.Load("ad", &x.ad) + m.Load("tmpfsFile", &x.tmpfsFile) + m.Load("name", &x.name) + m.Load("size", &x.size) + m.Load("perms", &x.perms) + m.Load("pb", &x.pb) +} + +func (x *Device) beforeSave() {} +func (x *Device) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *Device) afterLoad() {} +func (x *Device) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *PinBoard) beforeSave() {} +func (x *PinBoard) save(m state.Map) { + x.beforeSave() + m.Save("Set", &x.Set) +} + +func (x *PinBoard) afterLoad() {} +func (x *PinBoard) load(m state.Map) { + m.Load("Set", &x.Set) +} + +func (x *Range) beforeSave() {} +func (x *Range) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *Range) afterLoad() {} +func (x *Range) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *Set) beforeSave() {} +func (x *Set) save(m state.Map) { + x.beforeSave() + var root *SegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *Set) afterLoad() {} +func (x *Set) load(m state.Map) { + m.LoadValue("root", new(*SegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*SegmentDataSlices)) }) +} + +func (x *node) beforeSave() {} +func (x *node) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *node) afterLoad() {} +func (x *node) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *SegmentDataSlices) beforeSave() {} +func (x *SegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *SegmentDataSlices) afterLoad() {} +func (x *SegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("ashmem.Area", (*Area)(nil), state.Fns{Save: (*Area).save, Load: (*Area).load}) + state.Register("ashmem.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load}) + state.Register("ashmem.PinBoard", (*PinBoard)(nil), state.Fns{Save: (*PinBoard).save, Load: (*PinBoard).load}) + state.Register("ashmem.Range", (*Range)(nil), state.Fns{Save: (*Range).save, Load: (*Range).load}) + state.Register("ashmem.Set", (*Set)(nil), state.Fns{Save: (*Set).save, Load: (*Set).load}) + state.Register("ashmem.node", (*node)(nil), state.Fns{Save: (*node).save, Load: (*node).load}) + state.Register("ashmem.SegmentDataSlices", (*SegmentDataSlices)(nil), state.Fns{Save: (*SegmentDataSlices).save, Load: (*SegmentDataSlices).load}) +} diff --git a/pkg/sentry/fs/ashmem/pin_board_test.go b/pkg/sentry/fs/ashmem/pin_board_test.go deleted file mode 100644 index 1ccd38f56..000000000 --- a/pkg/sentry/fs/ashmem/pin_board_test.go +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ashmem - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -func TestPinBoard(t *testing.T) { - pb := NewPinBoard() - - // Confirm that all pages are pinned. - if !pb.RangePinnedStatus(RangeFromAshmemPin(linux.AshmemPin{0, 0})) { - t.Errorf("RangePinnedStatus(all pages) returned false (unpinned) at start.") - } - - // Unpin pages [1, 11) (counting from 0) - pb.UnpinRange(RangeFromAshmemPin(linux.AshmemPin{ - usermem.PageSize, - usermem.PageSize * 10, - })) - - // Confirm that pages [1, 11) are unpinned and that page 0 and pages - // larger than 10 are pinned. - pinned := []linux.AshmemPin{ - { - 0, - usermem.PageSize, - }, { - usermem.PageSize * 11, - 0, - }, - } - - for _, pin := range pinned { - if !pb.RangePinnedStatus(RangeFromAshmemPin(pin)) { - t.Errorf("RangePinnedStatus(AshmemPin{offset (pages): %v, len (pages): %v}) returned false (unpinned).", - pin.Offset, pin.Len) - } - } - - unpinned := []linux.AshmemPin{ - { - usermem.PageSize, - usermem.PageSize * 10, - }, - } - - for _, pin := range unpinned { - if pb.RangePinnedStatus(RangeFromAshmemPin(pin)) { - t.Errorf("RangePinnedStatus(AshmemPin{offset (pages): %v, len (pages): %v}) returned true (pinned).", - pin.Offset, pin.Len) - } - } - - // Pin pages [2, 6). - pb.PinRange(RangeFromAshmemPin(linux.AshmemPin{ - usermem.PageSize * 2, - usermem.PageSize * 4, - })) - - // Confirm that pages 0, [2, 6) and pages larger than 10 are pinned - // while others remain unpinned. - pinned = []linux.AshmemPin{ - { - 0, - usermem.PageSize, - }, - { - usermem.PageSize * 2, - usermem.PageSize * 4, - }, - { - usermem.PageSize * 11, - 0, - }, - } - - for _, pin := range pinned { - if !pb.RangePinnedStatus(RangeFromAshmemPin(pin)) { - t.Errorf("RangePinnedStatus(AshmemPin{offset (pages): %v, len (pages): %v}) returned false (unpinned).", - pin.Offset, pin.Len) - } - } - - unpinned = []linux.AshmemPin{ - { - usermem.PageSize, - usermem.PageSize, - }, { - usermem.PageSize * 6, - usermem.PageSize * 5, - }, - } - - for _, pin := range unpinned { - if pb.RangePinnedStatus(RangeFromAshmemPin(pin)) { - t.Errorf("RangePinnedStatus(AshmemPin{offset (pages): %v, len (pages): %v}) returned true (pinned).", - pin.Offset, pin.Len) - } - } - - // Status of a partially pinned range is unpinned. - if pb.RangePinnedStatus(RangeFromAshmemPin(linux.AshmemPin{0, 0})) { - t.Errorf("RangePinnedStatus(all pages) returned true (pinned).") - } - - // Pin the whole range again. - pb.PinRange(RangeFromAshmemPin(linux.AshmemPin{0, 0})) - - // Confirm that all pages are pinned. - if !pb.RangePinnedStatus(RangeFromAshmemPin(linux.AshmemPin{0, 0})) { - t.Errorf("RangePinnedStatus(all pages) returned false (unpinned) at start.") - } -} diff --git a/pkg/segment/range.go b/pkg/sentry/fs/ashmem/uint64_range.go index 4d4aeffef..d71a10b16 100644..100755 --- a/pkg/segment/range.go +++ b/pkg/sentry/fs/ashmem/uint64_range.go @@ -1,31 +1,14 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package segment - -// T is a required type parameter that must be an integral type. -type T uint64 +package ashmem // A Range represents a contiguous range of T. // // +stateify savable type Range struct { // Start is the inclusive start of the range. - Start T + Start uint64 // End is the exclusive end of the range. - End T + End uint64 } // WellFormed returns true if r.Start <= r.End. All other methods on a Range @@ -35,12 +18,12 @@ func (r Range) WellFormed() bool { } // Length returns the length of the range. -func (r Range) Length() T { +func (r Range) Length() uint64 { return r.End - r.Start } // Contains returns true if r contains x. -func (r Range) Contains(x T) bool { +func (r Range) Contains(x uint64) bool { return r.Start <= x && x < r.End } @@ -74,6 +57,6 @@ func (r Range) Intersect(r2 Range) Range { // CanSplitAt returns true if it is legal to split a segment spanning the range // r at x; that is, splitting at x would produce two ranges, both of which have // non-zero length. -func (r Range) CanSplitAt(x T) bool { +func (r Range) CanSplitAt(x uint64) bool { return r.Contains(x) && r.Start < x } diff --git a/pkg/segment/set.go b/pkg/sentry/fs/ashmem/uint64_set.go index 982eb3fdd..6e435325f 100644..100755 --- a/pkg/segment/set.go +++ b/pkg/sentry/fs/ashmem/uint64_set.go @@ -1,73 +1,10 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package segment provides tools for working with collections of segments. A -// segment is a key-value mapping, where the key is a non-empty contiguous -// range of values of type Key, and the value is a single value of type Value. -// -// Clients using this package must use the go_template_instance rule in -// tools/go_generics/defs.bzl to create an instantiation of this -// template package, providing types to use in place of Key, Range, Value, and -// Functions. See pkg/segment/test/BUILD for a usage example. -package segment +package ashmem import ( "bytes" "fmt" ) -// Key is a required type parameter that must be an integral type. -type Key uint64 - -// Range is a required type parameter equivalent to Range<Key>. -type Range interface{} - -// Value is a required type parameter. -type Value interface{} - -// Functions is a required type parameter that must be a struct implementing -// the methods defined by Functions. -type Functions interface { - // MinKey returns the minimum allowed key. - MinKey() Key - - // MaxKey returns the maximum allowed key + 1. - MaxKey() Key - - // ClearValue deinitializes the given value. (For example, if Value is a - // pointer or interface type, ClearValue should set it to nil.) - ClearValue(*Value) - - // Merge attempts to merge the values corresponding to two consecutive - // segments. If successful, Merge returns (merged value, true). Otherwise, - // it returns (unspecified, false). - // - // Preconditions: r1.End == r2.Start. - // - // Postconditions: If merging succeeds, val1 and val2 are invalidated. - Merge(r1 Range, val1 Value, r2 Range, val2 Value) (Value, bool) - - // Split splits a segment's value at a key within its range, such that the - // first returned value corresponds to the range [r.Start, split) and the - // second returned value corresponds to the range [split, r.End). - // - // Preconditions: r.Start < split < r.End. - // - // Postconditions: The original value val is invalidated. - Split(r Range, val Value, split Key) (Value, Value) -} - const ( // minDegree is the minimum degree of an internal node in a Set B-tree. // @@ -117,8 +54,8 @@ func (s *Set) IsEmptyRange(r Range) bool { } // Span returns the total size of all segments in the set. -func (s *Set) Span() Key { - var sz Key +func (s *Set) Span() uint64 { + var sz uint64 for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { sz += seg.Range().Length() } @@ -127,14 +64,14 @@ func (s *Set) Span() Key { // SpanRange returns the total size of the intersection of segments in the set // with the given range. -func (s *Set) SpanRange(r Range) Key { +func (s *Set) SpanRange(r Range) uint64 { switch { case r.Length() < 0: panic(fmt.Sprintf("invalid range %v", r)) case r.Length() == 0: return 0 } - var sz Key + var sz uint64 for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { sz += seg.Range().Intersect(r).Length() } @@ -181,11 +118,10 @@ func (s *Set) LastGap() GapIterator { // segment is found, the returned Iterator is non-terminal and the // returned GapIterator is terminal. Otherwise, the returned Iterator is // terminal and the returned GapIterator is non-terminal. -func (s *Set) Find(key Key) (Iterator, GapIterator) { +func (s *Set) Find(key uint64) (Iterator, GapIterator) { n := &s.root for { - // Binary search invariant: the correct value of i lies within [lower, - // upper]. + lower := 0 upper := n.nrSegments for lower < upper { @@ -209,7 +145,7 @@ func (s *Set) Find(key Key) (Iterator, GapIterator) { // FindSegment returns the segment whose range contains the given key. If no // such segment exists, FindSegment returns a terminal iterator. -func (s *Set) FindSegment(key Key) Iterator { +func (s *Set) FindSegment(key uint64) Iterator { seg, _ := s.Find(key) return seg } @@ -217,7 +153,7 @@ func (s *Set) FindSegment(key Key) Iterator { // LowerBoundSegment returns the segment with the lowest range that contains a // key greater than or equal to min. If no such segment exists, // LowerBoundSegment returns a terminal iterator. -func (s *Set) LowerBoundSegment(min Key) Iterator { +func (s *Set) LowerBoundSegment(min uint64) Iterator { seg, gap := s.Find(min) if seg.Ok() { return seg @@ -228,7 +164,7 @@ func (s *Set) LowerBoundSegment(min Key) Iterator { // UpperBoundSegment returns the segment with the highest range that contains a // key less than or equal to max. If no such segment exists, UpperBoundSegment // returns a terminal iterator. -func (s *Set) UpperBoundSegment(max Key) Iterator { +func (s *Set) UpperBoundSegment(max uint64) Iterator { seg, gap := s.Find(max) if seg.Ok() { return seg @@ -239,14 +175,14 @@ func (s *Set) UpperBoundSegment(max Key) Iterator { // FindGap returns the gap containing the given key. If no such gap exists // (i.e. the set contains a segment containing that key), FindGap returns a // terminal iterator. -func (s *Set) FindGap(key Key) GapIterator { +func (s *Set) FindGap(key uint64) GapIterator { _, gap := s.Find(key) return gap } // LowerBoundGap returns the gap with the lowest range that is greater than or // equal to min. -func (s *Set) LowerBoundGap(min Key) GapIterator { +func (s *Set) LowerBoundGap(min uint64) GapIterator { seg, gap := s.Find(min) if gap.Ok() { return gap @@ -256,7 +192,7 @@ func (s *Set) LowerBoundGap(min Key) GapIterator { // UpperBoundGap returns the gap with the highest range that is less than or // equal to max. -func (s *Set) UpperBoundGap(max Key) GapIterator { +func (s *Set) UpperBoundGap(max uint64) GapIterator { seg, gap := s.Find(max) if gap.Ok() { return gap @@ -268,7 +204,7 @@ func (s *Set) UpperBoundGap(max Key) GapIterator { // segment can be merged with adjacent segments, Add will do so. If the new // segment would overlap an existing segment, Add returns false. If Add // succeeds, all existing iterators are invalidated. -func (s *Set) Add(r Range, val Value) bool { +func (s *Set) Add(r Range, val noValue) bool { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -287,7 +223,7 @@ func (s *Set) Add(r Range, val Value) bool { // If it would overlap an existing segment, AddWithoutMerging does nothing and // returns false. If AddWithoutMerging succeeds, all existing iterators are // invalidated. -func (s *Set) AddWithoutMerging(r Range, val Value) bool { +func (s *Set) AddWithoutMerging(r Range, val noValue) bool { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -314,7 +250,7 @@ func (s *Set) AddWithoutMerging(r Range, val Value) bool { // Merge, but may be more efficient. Note that there is no unchecked variant of // Insert since Insert must retrieve and inspect gap's predecessor and // successor segments regardless. -func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { +func (s *Set) Insert(gap GapIterator, r Range, val noValue) Iterator { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -326,12 +262,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) } if prev.Ok() && prev.End() == r.Start { - if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + if mval, ok := (setFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { prev.SetEndUnchecked(r.End) prev.SetValue(mval) if next.Ok() && next.Start() == r.End { val = mval - if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + if mval, ok := (setFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { prev.SetEndUnchecked(next.End()) prev.SetValue(mval) return s.Remove(next).PrevSegment() @@ -341,7 +277,7 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } } if next.Ok() && next.Start() == r.End { - if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok { + if mval, ok := (setFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { next.SetStartUnchecked(r.Start) next.SetValue(mval) return next @@ -356,7 +292,7 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { // // If the gap cannot accommodate the segment, or if r is invalid, // InsertWithoutMerging panics. -func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator { +func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val noValue) Iterator { if r.Length() <= 0 { panic(fmt.Sprintf("invalid segment range %v", r)) } @@ -371,7 +307,7 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // (including gap, but not including the returned iterator) are invalidated. // // Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). -func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { +func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val noValue) Iterator { gap = gap.node.rebalanceBeforeInsert(gap) copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) @@ -385,25 +321,18 @@ func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) // All existing iterators (including seg, but not including the returned // iterator) are invalidated. func (s *Set) Remove(seg Iterator) GapIterator { - // We only want to remove directly from a leaf node. + if seg.node.hasChildren { - // Since seg.node has children, the removed segment must have a - // predecessor (at the end of the rightmost leaf of its left child - // subtree). Move the contents of that predecessor into the removed - // segment's position, and remove that predecessor instead. (We choose - // to steal the predecessor rather than the successor because removing - // from the end of a leaf node doesn't involve any copying unless - // merging is required.) + victim := seg.PrevSegment() - // This must be unchecked since until victim is removed, seg and victim - // overlap. + seg.SetRangeUnchecked(victim.Range()) seg.SetValue(victim.Value()) return s.Remove(victim).NextGap() } copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) - Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + setFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) seg.node.nrSegments-- return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index}) } @@ -450,9 +379,8 @@ func (s *Set) Merge(first, second Iterator) Iterator { // second, first == second.PrevSegment(). func (s *Set) MergeUnchecked(first, second Iterator) Iterator { if first.End() == second.Start() { - if mval, ok := (Functions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { - // N.B. This must be unchecked because until s.Remove(second), first - // overlaps second. + if mval, ok := (setFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + first.SetEndUnchecked(second.End()) first.SetValue(mval) return s.Remove(second).PrevSegment() @@ -520,7 +448,7 @@ func (s *Set) MergeAdjacent(r Range) { // end of the segment's range, so splitting would produce a segment with zero // length, or because split falls outside the segment's range altogether), // Split panics. -func (s *Set) Split(seg Iterator, split Key) (Iterator, Iterator) { +func (s *Set) Split(seg Iterator, split uint64) (Iterator, Iterator) { if !seg.Range().CanSplitAt(split) { panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) } @@ -532,20 +460,20 @@ func (s *Set) Split(seg Iterator, split Key) (Iterator, Iterator) { // seg, but not including the returned iterators) are invalidated. // // Preconditions: seg.Start() < key < seg.End(). -func (s *Set) SplitUnchecked(seg Iterator, split Key) (Iterator, Iterator) { - val1, val2 := (Functions{}).Split(seg.Range(), seg.Value(), split) +func (s *Set) SplitUnchecked(seg Iterator, split uint64) (Iterator, Iterator) { + val1, val2 := (setFunctions{}).Split(seg.Range(), seg.Value(), split) end2 := seg.End() seg.SetEndUnchecked(split) seg.SetValue(val1) seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), Range{split, end2}, val2) - // seg may now be invalid due to the Insert. + return seg2.PrevSegment(), seg2 } // SplitAt splits the segment straddling split, if one exists. SplitAt returns // true if a segment was split and false otherwise. If SplitAt splits a // segment, all existing iterators are invalidated. -func (s *Set) SplitAt(split Key) bool { +func (s *Set) SplitAt(split uint64) bool { if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { s.SplitUnchecked(seg, split) return true @@ -591,8 +519,7 @@ func (s *Set) ApplyContiguous(r Range, fn func(seg Iterator)) GapIterator { } seg = gap.NextSegment() if !seg.Ok() { - // This implies that the last segment extended all the - // way to the maximum value, since the gap was empty. + return GapIterator{} } } @@ -634,7 +561,7 @@ type node struct { // Nodes store keys and values in separate arrays to maximize locality in // the common case (scanning keys for lookup). keys [maxDegree - 1]Range - values [maxDegree - 1]Value + values [maxDegree - 1]noValue children [maxDegree]*node } @@ -683,9 +610,7 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { return gap } if n.parent == nil { - // n is root. Move all segments before and after n's median segment - // into new child nodes adjacent to the median segment, which is now - // the only segment in root. + left := &node{ nrSegments: minDegree - 1, parent: n, @@ -727,10 +652,7 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { } return GapIterator{right, gap.index - minDegree} } - // n is non-root. Move n's median segment into its parent node (which can't - // be full because we've already invoked n.parent.rebalanceBeforeInsert) - // and move all segments after n's median into a new sibling node (the - // median segment's right child subtree). + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[minDegree-1], n.values[minDegree-1] @@ -758,7 +680,7 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { } } n.nrSegments = minDegree - 1 - // gap.node can't be n.parent because gaps are always in leaf nodes. + if gap.node != n { return gap } @@ -780,27 +702,10 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { return gap } if n.parent == nil { - // Root is allowed to be deficient. + return gap } - // There's one other thing we can do before resorting to unsplitting. - // If either sibling node has at least minDegree segments, rotate that - // sibling's closest segment through the segment in the parent that - // separates us. That is, given: - // - // ... D ... - // / \ - // ... B C] [E ... - // - // where the node containing E is deficient, end up with: - // - // ... C ... - // / \ - // ... B] [D E ... - // - // As in Set.Remove, prefer rotating from the end of the sibling to the - // left: by precondition, n.node has fewer segments (to memcpy) than - // the sibling does. + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= minDegree { copy(n.keys[1:], n.keys[:n.nrSegments]) copy(n.values[1:], n.values[:n.nrSegments]) @@ -808,7 +713,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { n.values[0] = n.parent.values[n.parentIndex-1] n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] - Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + setFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) if n.hasChildren { copy(n.children[1:], n.children[:n.nrSegments+1]) n.children[0] = sibling.children[sibling.nrSegments] @@ -836,7 +741,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { n.parent.values[n.parentIndex] = sibling.values[0] copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) - Functions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + setFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) if n.hasChildren { n.children[n.nrSegments+1] = sibling.children[0] copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) @@ -857,15 +762,10 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } return gap } - // Otherwise, we must unsplit. + p := n.parent if p.nrSegments == 1 { - // Merge all segments in both n and its sibling back into n.parent. - // This is the reverse of the root splitting case in - // node.rebalanceBeforeInsert. (Because we require minDegree >= 3, - // only root can have 1 segment in this path, so this reduces the - // height of the tree by 1, without violating the constraint that - // all leaf nodes remain at the same depth.) + left, right := p.children[0], p.children[1] p.nrSegments = left.nrSegments + right.nrSegments + 1 p.hasChildren = left.hasChildren @@ -906,8 +806,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { left = n right = n.nextSibling() } - // Fix up gap first since we need the old left.nrSegments, which - // merging will change. + if gap.node == right { gap = GapIterator{left, gap.index + left.nrSegments + 1} } @@ -925,14 +824,14 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { left.nrSegments += right.nrSegments + 1 copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) - Functions{}.ClearValue(&p.values[p.nrSegments-1]) + setFunctions{}.ClearValue(&p.values[p.nrSegments-1]) copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) for i := 0; i < p.nrSegments; i++ { p.children[i].parentIndex = i } p.children[p.nrSegments] = nil p.nrSegments-- - // This process robs p of one segment, so recurse into rebalancing p. + n = p } } @@ -971,13 +870,13 @@ func (seg Iterator) Range() Range { // Start is equivalent to Range().Start, but should be preferred if only the // start of the range is needed. -func (seg Iterator) Start() Key { +func (seg Iterator) Start() uint64 { return seg.node.keys[seg.index].Start } // End is equivalent to Range().End, but should be preferred if only the end of // the range is needed. -func (seg Iterator) End() Key { +func (seg Iterator) End() uint64 { return seg.node.keys[seg.index].End } @@ -1017,7 +916,7 @@ func (seg Iterator) SetRange(r Range) { // // Preconditions: The new start must be valid: start < seg.End(); if // seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). -func (seg Iterator) SetStartUnchecked(start Key) { +func (seg Iterator) SetStartUnchecked(start uint64) { seg.node.keys[seg.index].Start = start } @@ -1025,7 +924,7 @@ func (seg Iterator) SetStartUnchecked(start Key) { // cause the iterated segment to overlap another segment, or would result in an // invalid range, SetStart panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetStart(start Key) { +func (seg Iterator) SetStart(start uint64) { if start >= seg.End() { panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) } @@ -1040,7 +939,7 @@ func (seg Iterator) SetStart(start Key) { // // Preconditions: The new end must be valid: end > seg.Start(); if // seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). -func (seg Iterator) SetEndUnchecked(end Key) { +func (seg Iterator) SetEndUnchecked(end uint64) { seg.node.keys[seg.index].End = end } @@ -1048,7 +947,7 @@ func (seg Iterator) SetEndUnchecked(end Key) { // the iterated segment to overlap another segment, or would result in an // invalid range, SetEnd panics. This operation does not invalidate any // iterators. -func (seg Iterator) SetEnd(end Key) { +func (seg Iterator) SetEnd(end uint64) { if end <= seg.Start() { panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) } @@ -1059,20 +958,20 @@ func (seg Iterator) SetEnd(end Key) { } // Value returns a copy of the iterated segment's value. -func (seg Iterator) Value() Value { +func (seg Iterator) Value() noValue { return seg.node.values[seg.index] } // ValuePtr returns a pointer to the iterated segment's value. The pointer is // invalidated if the iterator is invalidated. This operation does not // invalidate any iterators. -func (seg Iterator) ValuePtr() *Value { +func (seg Iterator) ValuePtr() *noValue { return &seg.node.values[seg.index] } // SetValue mutates the iterated segment's value. This operation does not // invalidate any iterators. -func (seg Iterator) SetValue(val Value) { +func (seg Iterator) SetValue(val noValue) { seg.node.values[seg.index] = val } @@ -1109,8 +1008,7 @@ func (seg Iterator) NextSegment() Iterator { // PrevGap returns the gap immediately before the iterated segment. func (seg Iterator) PrevGap() GapIterator { if seg.node.hasChildren { - // Note that this isn't recursive because the last segment in a subtree - // must be in a leaf node. + return seg.node.children[seg.index].lastSegment().NextGap() } return GapIterator{seg.node, seg.index} @@ -1189,20 +1087,20 @@ func (gap GapIterator) Range() Range { // Start is equivalent to Range().Start, but should be preferred if only the // start of the range is needed. -func (gap GapIterator) Start() Key { +func (gap GapIterator) Start() uint64 { if ps := gap.PrevSegment(); ps.Ok() { return ps.End() } - return Functions{}.MinKey() + return setFunctions{}.MinKey() } // End is equivalent to Range().End, but should be preferred if only the end of // the range is needed. -func (gap GapIterator) End() Key { +func (gap GapIterator) End() uint64 { if ns := gap.NextSegment(); ns.Ok() { return ns.Start() } - return Functions{}.MaxKey() + return setFunctions{}.MaxKey() } // IsEmpty returns true if the iterated gap is empty (that is, the "gap" is @@ -1269,11 +1167,10 @@ func segmentAfterPosition(n *node, i int) Iterator { return Iterator{n, i} } -func zeroValueSlice(slice []Value) { - // TODO(jamieliu): check if Go is actually smart enough to optimize a - // ClearValue that assigns nil to a memset here +func zeroValueSlice(slice []noValue) { + for i := range slice { - Functions{}.ClearValue(&slice[i]) + setFunctions{}.ClearValue(&slice[i]) } } @@ -1323,9 +1220,9 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { // // +stateify savable type SegmentDataSlices struct { - Start []Key - End []Key - Values []Value + Start []uint64 + End []uint64 + Values []noValue } // ExportSortedSlice returns a copy of all segments in the given set, in ascending @@ -1362,3 +1259,12 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { } return nil } +func (s *Set) saveRoot() *SegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *Set) loadRoot(sds *SegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/binder/BUILD b/pkg/sentry/fs/binder/BUILD deleted file mode 100644 index 639a5603a..000000000 --- a/pkg/sentry/fs/binder/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "binder", - srcs = [ - "binder.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/binder", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/binder/binder_state_autogen.go b/pkg/sentry/fs/binder/binder_state_autogen.go new file mode 100755 index 000000000..1f321e3b6 --- /dev/null +++ b/pkg/sentry/fs/binder/binder_state_autogen.go @@ -0,0 +1,40 @@ +// automatically generated by stateify. + +package binder + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Device) beforeSave() {} +func (x *Device) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *Device) afterLoad() {} +func (x *Device) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *Proc) beforeSave() {} +func (x *Proc) save(m state.Map) { + x.beforeSave() + m.Save("bd", &x.bd) + m.Save("task", &x.task) + m.Save("mfp", &x.mfp) + m.Save("mapped", &x.mapped) +} + +func (x *Proc) afterLoad() {} +func (x *Proc) load(m state.Map) { + m.Load("bd", &x.bd) + m.Load("task", &x.task) + m.Load("mfp", &x.mfp) + m.Load("mapped", &x.mapped) +} + +func init() { + state.Register("binder.Device", (*Device)(nil), state.Fns{Save: (*Device).save, Load: (*Device).load}) + state.Register("binder.Proc", (*Proc)(nil), state.Fns{Save: (*Proc).save, Load: (*Proc).load}) +} diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go deleted file mode 100644 index 72110f856..000000000 --- a/pkg/sentry/fs/copy_up_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" - "sync" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/fs" - _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -const ( - // origFileSize is the original file size. This many bytes should be - // copied up before the test file is modified. - origFileSize = 4096 - - // truncatedFileSize is the size to truncate all test files. - truncateFileSize = 10 -) - -// TestConcurrentCopyUp is a copy up stress test for an overlay. -// -// It creates a 64-level deep directory tree in the lower filesystem and -// populates the last subdirectory with 64 files containing random content: -// -// /lower -// /sudir0/.../subdir63/ -// /file0 -// ... -// /file63 -// -// The files are truncated concurrently by 4 goroutines per file. -// These goroutines contend with copying up all parent 64 subdirectories -// as well as the final file content. -// -// At the end of the test, we assert that the files respect the new truncated -// size and contain the content we expect. -func TestConcurrentCopyUp(t *testing.T) { - ctx := contexttest.Context(t) - files := makeOverlayTestFiles(t) - - var wg sync.WaitGroup - for _, file := range files { - for i := 0; i < 4; i++ { - wg.Add(1) - go func(o *overlayTestFile) { - if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Fatalf("failed to copy up: %v", err) - } - wg.Done() - }(file) - } - } - wg.Wait() - - for _, file := range files { - got := make([]byte, origFileSize) - n, err := file.File.Readv(ctx, usermem.BytesIOSequence(got)) - if int(n) != truncateFileSize { - t.Fatalf("read %d bytes from file, want %d", n, truncateFileSize) - } - if err != nil && err != io.EOF { - t.Fatalf("read got error %v, want nil", err) - } - if !bytes.Equal(got[:n], file.content[:truncateFileSize]) { - t.Fatalf("file content is %v, want %v", got[:n], file.content[:truncateFileSize]) - } - } -} - -type overlayTestFile struct { - File *fs.File - name string - content []byte -} - -func makeOverlayTestFiles(t *testing.T) []*overlayTestFile { - ctx := contexttest.Context(t) - - // Create a lower tmpfs mount. - fsys, _ := fs.FindFilesystem("tmpfs") - lower, err := fsys.Mount(contexttest.Context(t), "", fs.MountSourceFlags{}, "", nil) - if err != nil { - t.Fatalf("failed to mount tmpfs: %v", err) - } - lowerRoot := fs.NewDirent(lower, "") - - // Make a deep set of subdirectories that everyone shares. - next := lowerRoot - for i := 0; i < 64; i++ { - name := fmt.Sprintf("subdir%d", i) - err := next.CreateDirectory(ctx, lowerRoot, name, fs.FilePermsFromMode(0777)) - if err != nil { - t.Fatalf("failed to create dir %q: %v", name, err) - } - next, err = next.Walk(ctx, lowerRoot, name) - if err != nil { - t.Fatalf("failed to walk to %q: %v", name, err) - } - } - - // Make a bunch of files in the last directory. - var files []*overlayTestFile - for i := 0; i < 64; i++ { - name := fmt.Sprintf("file%d", i) - f, err := next.Create(ctx, next, name, fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - t.Fatalf("failed to create file %q: %v", name, err) - } - defer f.DecRef() - - relname, _ := f.Dirent.FullName(lowerRoot) - - o := &overlayTestFile{ - name: relname, - content: make([]byte, origFileSize), - } - - if _, err := rand.Read(o.content); err != nil { - t.Fatalf("failed to read from /dev/urandom: %v", err) - } - - if _, err := f.Writev(ctx, usermem.BytesIOSequence(o.content)); err != nil { - t.Fatalf("failed to write content to file %q: %v", name, err) - } - - files = append(files, o) - } - - // Create an empty upper tmpfs mount which we will copy up into. - upper, err := fsys.Mount(ctx, "", fs.MountSourceFlags{}, "", nil) - if err != nil { - t.Fatalf("failed to mount tmpfs: %v", err) - } - - // Construct an overlay root. - overlay, err := fs.NewOverlayRoot(ctx, upper, lower, fs.MountSourceFlags{}) - if err != nil { - t.Fatalf("failed to construct overlay root: %v", err) - } - - // Create a MountNamespace to traverse the file system. - mns, err := fs.NewMountNamespace(ctx, overlay) - if err != nil { - t.Fatalf("failed to construct mount manager: %v", err) - } - - // Walk to all of the files in the overlay, open them readable. - for _, f := range files { - maxTraversals := uint(0) - d, err := mns.FindInode(ctx, mns.Root(), mns.Root(), f.name, &maxTraversals) - if err != nil { - t.Fatalf("failed to find %q: %v", f.name, err) - } - defer d.DecRef() - - f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("failed to open file %q readable: %v", f.name, err) - } - } - - return files -} diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD deleted file mode 100644 index a9b03d172..000000000 --- a/pkg/sentry/fs/dev/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "dev", - srcs = [ - "dev.go", - "device.go", - "fs.go", - "full.go", - "null.go", - "random.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/dev", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/rand", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/ashmem", - "//pkg/sentry/fs/binder", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/dev/dev_state_autogen.go b/pkg/sentry/fs/dev/dev_state_autogen.go new file mode 100755 index 000000000..0afbc170f --- /dev/null +++ b/pkg/sentry/fs/dev/dev_state_autogen.go @@ -0,0 +1,108 @@ +// automatically generated by stateify. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *fullDevice) beforeSave() {} +func (x *fullDevice) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *fullDevice) afterLoad() {} +func (x *fullDevice) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *fullFileOperations) beforeSave() {} +func (x *fullFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *fullFileOperations) afterLoad() {} +func (x *fullFileOperations) load(m state.Map) { +} + +func (x *nullDevice) beforeSave() {} +func (x *nullDevice) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *nullDevice) afterLoad() {} +func (x *nullDevice) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *nullFileOperations) beforeSave() {} +func (x *nullFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *nullFileOperations) afterLoad() {} +func (x *nullFileOperations) load(m state.Map) { +} + +func (x *zeroDevice) beforeSave() {} +func (x *zeroDevice) save(m state.Map) { + x.beforeSave() + m.Save("nullDevice", &x.nullDevice) +} + +func (x *zeroDevice) afterLoad() {} +func (x *zeroDevice) load(m state.Map) { + m.Load("nullDevice", &x.nullDevice) +} + +func (x *zeroFileOperations) beforeSave() {} +func (x *zeroFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *zeroFileOperations) afterLoad() {} +func (x *zeroFileOperations) load(m state.Map) { +} + +func (x *randomDevice) beforeSave() {} +func (x *randomDevice) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *randomDevice) afterLoad() {} +func (x *randomDevice) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *randomFileOperations) beforeSave() {} +func (x *randomFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *randomFileOperations) afterLoad() {} +func (x *randomFileOperations) load(m state.Map) { +} + +func init() { + state.Register("dev.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("dev.fullDevice", (*fullDevice)(nil), state.Fns{Save: (*fullDevice).save, Load: (*fullDevice).load}) + state.Register("dev.fullFileOperations", (*fullFileOperations)(nil), state.Fns{Save: (*fullFileOperations).save, Load: (*fullFileOperations).load}) + state.Register("dev.nullDevice", (*nullDevice)(nil), state.Fns{Save: (*nullDevice).save, Load: (*nullDevice).load}) + state.Register("dev.nullFileOperations", (*nullFileOperations)(nil), state.Fns{Save: (*nullFileOperations).save, Load: (*nullFileOperations).load}) + state.Register("dev.zeroDevice", (*zeroDevice)(nil), state.Fns{Save: (*zeroDevice).save, Load: (*zeroDevice).load}) + state.Register("dev.zeroFileOperations", (*zeroFileOperations)(nil), state.Fns{Save: (*zeroFileOperations).save, Load: (*zeroFileOperations).load}) + state.Register("dev.randomDevice", (*randomDevice)(nil), state.Fns{Save: (*randomDevice).save, Load: (*randomDevice).load}) + state.Register("dev.randomFileOperations", (*randomFileOperations)(nil), state.Fns{Save: (*randomFileOperations).save, Load: (*randomFileOperations).load}) +} diff --git a/pkg/sentry/fs/dirent_cache_test.go b/pkg/sentry/fs/dirent_cache_test.go deleted file mode 100644 index 395c879f5..000000000 --- a/pkg/sentry/fs/dirent_cache_test.go +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "testing" -) - -func TestDirentCache(t *testing.T) { - const maxSize = 5 - - c := NewDirentCache(maxSize) - - // Size starts at 0. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Create a Dirent d. - d := NewNegativeDirent("") - - // c does not contain d. - if got, want := c.contains(d), false; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add d to the cache. - c.Add(d) - - // Size is now 1. - if got, want := c.Size(), uint64(1); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add maxSize-1 more elements. d should be oldest element. - for i := 0; i < maxSize-1; i++ { - c.Add(NewNegativeDirent("")) - } - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // "Bump" d to the front by re-adding it. - c.Add(d) - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add maxSize-1 more elements. d should again be oldest element. - for i := 0; i < maxSize-1; i++ { - c.Add(NewNegativeDirent("")) - } - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c contains d. - if got, want := c.contains(d), true; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Add one more element, which will bump d from the cache. - c.Add(NewNegativeDirent("")) - - // Size is maxSize. - if got, want := c.Size(), uint64(maxSize); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // c does not contain d. - if got, want := c.contains(d), false; got != want { - t.Errorf("c.contains(d) got %v want %v", got, want) - } - - // Invalidating causes size to be 0 and list to be empty. - c.Invalidate() - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - if got, want := c.list.Empty(), true; got != want { - t.Errorf("c.list.Empty() got %v, want %v", got, want) - } - - // Fill cache with maxSize dirents. - for i := 0; i < maxSize; i++ { - c.Add(NewNegativeDirent("")) - } -} - -func TestDirentCacheLimiter(t *testing.T) { - const ( - globalMaxSize = 5 - maxSize = 3 - ) - - limit := NewDirentCacheLimiter(globalMaxSize) - c1 := NewDirentCache(maxSize) - c1.limit = limit - c2 := NewDirentCache(maxSize) - c2.limit = limit - - // Create a Dirent d. - d := NewNegativeDirent("") - - // Add d to the cache. - c1.Add(d) - if got, want := c1.Size(), uint64(1); got != want { - t.Errorf("c1.Size() got %v, want %v", got, want) - } - - // Add maxSize-1 more elements. d should be oldest element. - for i := 0; i < maxSize-1; i++ { - c1.Add(NewNegativeDirent("")) - } - if got, want := c1.Size(), uint64(maxSize); got != want { - t.Errorf("c1.Size() got %v, want %v", got, want) - } - - // Check that d is still there. - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Fill up the other cache, it will start dropping old entries from the cache - // when the global limit is reached. - for i := 0; i < maxSize; i++ { - c2.Add(NewNegativeDirent("")) - } - - // Check is what's remaining from global max. - if got, want := c2.Size(), globalMaxSize-maxSize; int(got) != want { - t.Errorf("c2.Size() got %v, want %v", got, want) - } - - // Check that d was not dropped. - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Add an entry that will eventually be dropped. Check is done later... - drop := NewNegativeDirent("") - c1.Add(drop) - - // Check that d is bumped to front even when global limit is reached. - c1.Add(d) - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - - // Add 2 more element and check that: - // - d is still in the list: to verify that d was bumped - // - d2/d3 are in the list: older entries are dropped when global limit is - // reached. - // - drop is not in the list: indeed older elements are dropped. - d2 := NewNegativeDirent("") - c1.Add(d2) - d3 := NewNegativeDirent("") - c1.Add(d3) - if got, want := c1.contains(d), true; got != want { - t.Errorf("c1.contains(d) got %v want %v", got, want) - } - if got, want := c1.contains(d2), true; got != want { - t.Errorf("c1.contains(d2) got %v want %v", got, want) - } - if got, want := c1.contains(d3), true; got != want { - t.Errorf("c1.contains(d3) got %v want %v", got, want) - } - if got, want := c1.contains(drop), false; got != want { - t.Errorf("c1.contains(drop) got %v want %v", got, want) - } - - // Drop all entries from one cache. The other will be allowed to grow. - c1.Invalidate() - c2.Add(NewNegativeDirent("")) - if got, want := c2.Size(), uint64(maxSize); got != want { - t.Errorf("c2.Size() got %v, want %v", got, want) - } -} - -// TestNilDirentCache tests that a nil cache supports all cache operations, but -// treats them as noop. -func TestNilDirentCache(t *testing.T) { - // Create a nil cache. - var c *DirentCache - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Add. - c.Add(NewNegativeDirent("")) - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Remove. - c.Remove(NewNegativeDirent("")) - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } - - // Call Invalidate. - c.Invalidate() - - // Size is zero. - if got, want := c.Size(), uint64(0); got != want { - t.Errorf("c.Size() got %v, want %v", got, want) - } -} diff --git a/pkg/sentry/fs/dirent_list.go b/pkg/sentry/fs/dirent_list.go new file mode 100755 index 000000000..750961a48 --- /dev/null +++ b/pkg/sentry/fs/dirent_list.go @@ -0,0 +1,173 @@ +package fs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type direntElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (direntElementMapper) linkerFor(elem *Dirent) *Dirent { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type direntList struct { + head *Dirent + tail *Dirent +} + +// Reset resets list l to the empty state. +func (l *direntList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *direntList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *direntList) Front() *Dirent { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *direntList) Back() *Dirent { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *direntList) PushFront(e *Dirent) { + direntElementMapper{}.linkerFor(e).SetNext(l.head) + direntElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + direntElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *direntList) PushBack(e *Dirent) { + direntElementMapper{}.linkerFor(e).SetNext(nil) + direntElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + direntElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *direntList) PushBackList(m *direntList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + direntElementMapper{}.linkerFor(l.tail).SetNext(m.head) + direntElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *direntList) InsertAfter(b, e *Dirent) { + a := direntElementMapper{}.linkerFor(b).Next() + direntElementMapper{}.linkerFor(e).SetNext(a) + direntElementMapper{}.linkerFor(e).SetPrev(b) + direntElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + direntElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *direntList) InsertBefore(a, e *Dirent) { + b := direntElementMapper{}.linkerFor(a).Prev() + direntElementMapper{}.linkerFor(e).SetNext(a) + direntElementMapper{}.linkerFor(e).SetPrev(b) + direntElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + direntElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *direntList) Remove(e *Dirent) { + prev := direntElementMapper{}.linkerFor(e).Prev() + next := direntElementMapper{}.linkerFor(e).Next() + + if prev != nil { + direntElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + direntElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type direntEntry struct { + next *Dirent + prev *Dirent +} + +// Next returns the entry that follows e in the list. +func (e *direntEntry) Next() *Dirent { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *direntEntry) Prev() *Dirent { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *direntEntry) SetNext(elem *Dirent) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *direntEntry) SetPrev(elem *Dirent) { + e.prev = elem +} diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go deleted file mode 100644 index 73ffd15a6..000000000 --- a/pkg/sentry/fs/dirent_refs_test.go +++ /dev/null @@ -1,417 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" -) - -func newMockDirInode(ctx context.Context, cache *DirentCache) *Inode { - return NewMockInode(ctx, NewMockMountSource(cache), StableAttr{Type: Directory}) -} - -func TestWalkPositive(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(newMockDirInode(ctx, nil), "root") - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - name := "d" - d, err := root.walk(ctx, root, name, false) - if err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - - if got := root.ReadRefs(); got != 2 { - t.Fatalf("root has a ref count of %d, want %d", got, 2) - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1) - } - - d.DecRef() - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - if got := d.ReadRefs(); got != 0 { - t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0) - } - - root.flush() - - if got := len(root.children); got != 0 { - t.Fatalf("root has %d children, want %d", got, 0) - } -} - -func TestWalkNegative(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(NewEmptyDir(ctx, nil), "root") - mn := root.Inode.InodeOperations.(*mockInodeOperationsLookupNegative) - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - name := "d" - for i := 0; i < 100; i++ { - _, err := root.walk(ctx, root, name, false) - if err != syscall.ENOENT { - t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, syscall.ENOENT) - } - } - - if got := root.ReadRefs(); got != 1 { - t.Fatalf("root has a ref count of %d, want %d", got, 1) - } - - if got := len(root.children); got != 1 { - t.Fatalf("root has %d children, want %d", got, 1) - } - - w, ok := root.children[name] - if !ok { - t.Fatalf("root wants child at %q", name) - } - - child := w.Get() - if child == nil { - t.Fatalf("root wants to resolve weak reference") - } - - if !child.(*Dirent).IsNegative() { - t.Fatalf("root found positive child at %q, want negative", name) - } - - if got := child.(*Dirent).ReadRefs(); got != 2 { - t.Fatalf("child has a ref count of %d, want %d", got, 2) - } - - child.DecRef() - - if got := child.(*Dirent).ReadRefs(); got != 1 { - t.Fatalf("child has a ref count of %d, want %d", got, 1) - } - - if got := len(root.children); got != 1 { - t.Fatalf("root has %d children, want %d", got, 1) - } - - root.DecRef() - - if got := root.ReadRefs(); got != 0 { - t.Fatalf("root has a ref count of %d, want %d", got, 0) - } - - AsyncBarrier() - - if got := mn.releaseCalled; got != true { - t.Fatalf("root.Close was called %v, want true", got) - } -} - -type mockInodeOperationsLookupNegative struct { - *MockInodeOperations - releaseCalled bool -} - -func NewEmptyDir(ctx context.Context, cache *DirentCache) *Inode { - m := NewMockMountSource(cache) - return NewInode(&mockInodeOperationsLookupNegative{ - MockInodeOperations: NewMockInodeOperations(ctx), - }, m, StableAttr{Type: Directory}) -} - -func (m *mockInodeOperationsLookupNegative) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) { - return NewNegativeDirent(p), nil -} - -func (m *mockInodeOperationsLookupNegative) Release(context.Context) { - m.releaseCalled = true -} - -func TestHashNegativeToPositive(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - root := NewDirent(NewEmptyDir(ctx, nil), "root") - - name := "d" - _, err := root.walk(ctx, root, name, false) - if err != syscall.ENOENT { - t.Fatalf("root.walk(root, %q) got %v, want %v", name, err, syscall.ENOENT) - } - - if got := root.exists(ctx, root, name); got != false { - t.Fatalf("got %q exists, want does not exist", name) - } - - f, err := root.Create(ctx, root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(%q, _), got error %v, want nil", name, err) - } - d := f.Dirent - - if d.IsNegative() { - t.Fatalf("got negative Dirent, want positive") - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("child %q has a ref count of %d, want %d", name, got, 1) - } - - if got := root.ReadRefs(); got != 2 { - t.Fatalf("root has a ref count of %d, want %d", got, 2) - } - - if got := len(root.children); got != 1 { - t.Fatalf("got %d children, want %d", got, 1) - } - - w, ok := root.children[name] - if !ok { - t.Fatalf("failed to find weak reference to %q", name) - } - - child := w.Get() - if child == nil { - t.Fatalf("want to resolve weak reference") - } - - if child.(*Dirent) != d { - t.Fatalf("got foreign child") - } -} - -func TestRevalidate(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // Whether to make negative Dirents. - makeNegative bool - }{ - { - desc: "Revalidate negative Dirent", - makeNegative: true, - }, - { - desc: "Revalidate positive Dirent", - makeNegative: false, - }, - } { - t.Run(test.desc, func(t *testing.T) { - root := NewDirent(NewMockInodeRevalidate(ctx, test.makeNegative), "root") - - name := "d" - d1, err := root.walk(ctx, root, name, false) - if !test.makeNegative && err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - d2, err := root.walk(ctx, root, name, false) - if !test.makeNegative && err != nil { - t.Fatalf("root.walk(root, %q) got %v, want nil", name, err) - } - if !test.makeNegative && d1 == d2 { - t.Fatalf("revalidating walk got same *Dirent, want different") - } - if got := len(root.children); got != 1 { - t.Errorf("revalidating walk got %d children, want %d", got, 1) - } - }) - } -} - -type MockInodeOperationsRevalidate struct { - *MockInodeOperations - makeNegative bool -} - -func NewMockInodeRevalidate(ctx context.Context, makeNegative bool) *Inode { - mn := NewMockInodeOperations(ctx) - m := NewMockMountSource(nil) - m.MountSourceOperations.(*MockMountSourceOps).revalidate = true - return NewInode(&MockInodeOperationsRevalidate{MockInodeOperations: mn, makeNegative: makeNegative}, m, StableAttr{Type: Directory}) -} - -func (m *MockInodeOperationsRevalidate) Lookup(ctx context.Context, dir *Inode, p string) (*Dirent, error) { - if !m.makeNegative { - return m.MockInodeOperations.Lookup(ctx, dir, p) - } - return NewNegativeDirent(p), nil -} - -func TestCreateExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // root is the Dirent to create from. - root *Dirent - - // expected references on walked Dirent. - refs int64 - }{ - { - desc: "Create caching", - root: NewDirent(NewEmptyDir(ctx, NewDirentCache(1)), "root"), - refs: 2, - }, - { - desc: "Create not caching", - root: NewDirent(NewEmptyDir(ctx, nil), "root"), - refs: 1, - }, - } { - t.Run(test.desc, func(t *testing.T) { - name := "d" - f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(root, %q) failed: %v", name, err) - } - d := f.Dirent - - if got := d.ReadRefs(); got != test.refs { - t.Errorf("dirent has a ref count of %d, want %d", got, test.refs) - } - }) - } -} - -func TestRemoveExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // root is the Dirent to make and remove from. - root *Dirent - }{ - { - desc: "Remove caching", - root: NewDirent(NewEmptyDir(ctx, NewDirentCache(1)), "root"), - }, - { - desc: "Remove not caching", - root: NewDirent(NewEmptyDir(ctx, nil), "root"), - }, - } { - t.Run(test.desc, func(t *testing.T) { - name := "d" - f, err := test.root.Create(ctx, test.root, name, FileFlags{}, FilePermissions{}) - if err != nil { - t.Fatalf("root.Create(%q, _) failed: %v", name, err) - } - d := f.Dirent - - if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil { - t.Fatalf("root.Remove(root, %q) failed: %v", name, err) - } - - if got := d.ReadRefs(); got != 1 { - t.Fatalf("dirent has a ref count of %d, want %d", got, 1) - } - - d.DecRef() - - test.root.flush() - - if got := len(test.root.children); got != 0 { - t.Errorf("root has %d children, want %d", got, 0) - } - }) - } -} - -func TestRenameExtraRefs(t *testing.T) { - // refs == 0 -> one reference. - // refs == -1 -> has been destroyed. - - ctx := contexttest.Context(t) - for _, test := range []struct { - // desc is the test's description. - desc string - - // cache of extra Dirent references, may be nil. - cache *DirentCache - }{ - { - desc: "Rename no caching", - cache: nil, - }, - { - desc: "Rename caching", - cache: NewDirentCache(5), - }, - } { - t.Run(test.desc, func(t *testing.T) { - dirAttr := StableAttr{Type: Directory} - - oldParent := NewDirent(NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "old_parent") - newParent := NewDirent(NewMockInode(ctx, NewMockMountSource(test.cache), dirAttr), "new_parent") - - renamed, err := oldParent.Walk(ctx, oldParent, "old_child") - if err != nil { - t.Fatalf("Walk(oldParent, %q) got error %v, want nil", "old_child", err) - } - replaced, err := newParent.Walk(ctx, oldParent, "new_child") - if err != nil { - t.Fatalf("Walk(newParent, %q) got error %v, want nil", "new_child", err) - } - - if err := Rename(contexttest.RootContext(t), oldParent /*root */, oldParent, "old_child", newParent, "new_child"); err != nil { - t.Fatalf("Rename got error %v, want nil", err) - } - - oldParent.flush() - newParent.flush() - - // Expect to have only active references. - if got := renamed.ReadRefs(); got != 1 { - t.Errorf("renamed has ref count %d, want only active references %d", got, 1) - } - if got := replaced.ReadRefs(); got != 1 { - t.Errorf("replaced has ref count %d, want only active references %d", got, 1) - } - }) - } -} diff --git a/pkg/sentry/fs/event_list.go b/pkg/sentry/fs/event_list.go new file mode 100755 index 000000000..c94cb03e1 --- /dev/null +++ b/pkg/sentry/fs/event_list.go @@ -0,0 +1,173 @@ +package fs + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type eventElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (eventElementMapper) linkerFor(elem *Event) *Event { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type eventList struct { + head *Event + tail *Event +} + +// Reset resets list l to the empty state. +func (l *eventList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *eventList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *eventList) Front() *Event { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *eventList) Back() *Event { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *eventList) PushFront(e *Event) { + eventElementMapper{}.linkerFor(e).SetNext(l.head) + eventElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + eventElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *eventList) PushBack(e *Event) { + eventElementMapper{}.linkerFor(e).SetNext(nil) + eventElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *eventList) PushBackList(m *eventList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + eventElementMapper{}.linkerFor(l.tail).SetNext(m.head) + eventElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *eventList) InsertAfter(b, e *Event) { + a := eventElementMapper{}.linkerFor(b).Next() + eventElementMapper{}.linkerFor(e).SetNext(a) + eventElementMapper{}.linkerFor(e).SetPrev(b) + eventElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + eventElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *eventList) InsertBefore(a, e *Event) { + b := eventElementMapper{}.linkerFor(a).Prev() + eventElementMapper{}.linkerFor(e).SetNext(a) + eventElementMapper{}.linkerFor(e).SetPrev(b) + eventElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + eventElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *eventList) Remove(e *Event) { + prev := eventElementMapper{}.linkerFor(e).Prev() + next := eventElementMapper{}.linkerFor(e).Next() + + if prev != nil { + eventElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + eventElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type eventEntry struct { + next *Event + prev *Event +} + +// Next returns the entry that follows e in the list. +func (e *eventEntry) Next() *Event { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *eventEntry) Prev() *Event { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *eventEntry) SetNext(elem *Event) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *eventEntry) SetPrev(elem *Event) { + e.prev = elem +} diff --git a/pkg/sentry/fs/ext4/BUILD b/pkg/sentry/fs/ext4/BUILD deleted file mode 100644 index 9dce67635..000000000 --- a/pkg/sentry/fs/ext4/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "ext4", - srcs = ["fs.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ext4", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/fs", - ], -) diff --git a/pkg/sentry/fs/ext4/fs.go b/pkg/sentry/fs/ext4/fs.go deleted file mode 100644 index 5c7274821..000000000 --- a/pkg/sentry/fs/ext4/fs.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package ext4 implements the ext4 filesystem. -package ext4 - -import ( - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// filesystem implements fs.Filesystem for ext4. -// -// +stateify savable -type filesystem struct{} - -func init() { - fs.RegisterFilesystem(&filesystem{}) -} - -// FilesystemName is the name under which the filesystem is registered. -// Name matches fs/ext4/super.c:ext4_fs_type.name. -const FilesystemName = "ext4" - -// Name is the name of the file system. -func (*filesystem) Name() string { - return FilesystemName -} - -// AllowUserMount prohibits users from using mount(2) with this file system. -func (*filesystem) AllowUserMount() bool { - return false -} - -// AllowUserList prohibits this filesystem to be listed in /proc/filesystems. -func (*filesystem) AllowUserList() bool { - return false -} - -// Flags returns properties of the filesystem. -// -// In Linux, ext4 returns FS_REQUIRES_DEV. See fs/ext4/super.c -func (*filesystem) Flags() fs.FilesystemFlags { - return fs.FilesystemRequiresDev -} - -// Mount returns the root inode of the ext4 fs. -func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, cgroupsInt interface{}) (*fs.Inode, error) { - panic("unimplemented") -} diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD deleted file mode 100644 index bf00b9c09..000000000 --- a/pkg/sentry/fs/fdpipe/BUILD +++ /dev/null @@ -1,48 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "fdpipe", - srcs = [ - "pipe.go", - "pipe_opener.go", - "pipe_state.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fdpipe", - imports = ["gvisor.dev/gvisor/pkg/sentry/fs"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/secio", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "fdpipe_test", - size = "small", - srcs = [ - "pipe_opener_test.go", - "pipe_test.go", - ], - embed = [":fdpipe"], - deps = [ - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/usermem", - "//pkg/syserror", - "@com_github_google_uuid//:go_default_library", - ], -) diff --git a/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go new file mode 100755 index 000000000..38c1ed916 --- /dev/null +++ b/pkg/sentry/fs/fdpipe/fdpipe_state_autogen.go @@ -0,0 +1,27 @@ +// automatically generated by stateify. + +package fdpipe + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/sentry/fs" +) + +func (x *pipeOperations) save(m state.Map) { + x.beforeSave() + var flags fs.FileFlags = x.saveFlags() + m.SaveValue("flags", flags) + m.Save("opener", &x.opener) + m.Save("readAheadBuffer", &x.readAheadBuffer) +} + +func (x *pipeOperations) load(m state.Map) { + m.LoadWait("opener", &x.opener) + m.Load("readAheadBuffer", &x.readAheadBuffer) + m.LoadValue("flags", new(fs.FileFlags), func(y interface{}) { x.loadFlags(y.(fs.FileFlags)) }) + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("fdpipe.pipeOperations", (*pipeOperations)(nil), state.Fns{Save: (*pipeOperations).save, Load: (*pipeOperations).load}) +} diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go deleted file mode 100644 index 7ae61b711..000000000 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ /dev/null @@ -1,522 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdpipe - -import ( - "bytes" - "fmt" - "io" - "os" - "path" - "syscall" - "testing" - "time" - - "github.com/google/uuid" - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -type hostOpener struct { - name string -} - -func (h *hostOpener) NonBlockingOpen(_ context.Context, p fs.PermMask) (*fd.FD, error) { - var flags int - switch { - case p.Read && p.Write: - flags = syscall.O_RDWR - case p.Write: - flags = syscall.O_WRONLY - case p.Read: - flags = syscall.O_RDONLY - default: - return nil, syscall.EINVAL - } - f, err := syscall.Open(h.name, flags|syscall.O_NONBLOCK, 0666) - if err != nil { - return nil, err - } - return fd.New(f), nil -} - -func pipename() string { - return fmt.Sprintf(path.Join(os.TempDir(), "test-named-pipe-%s"), uuid.New()) -} - -func mkpipe(name string) error { - return syscall.Mknod(name, syscall.S_IFIFO|0666, 0) -} - -func TestTryOpen(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // makePipe is true if the test case should create the pipe. - makePipe bool - - // flags are the fs.FileFlags used to open the pipe. - flags fs.FileFlags - - // expectFile is true if a fs.File is expected. - expectFile bool - - // err is the expected error - err error - }{ - { - desc: "FileFlags lacking Read and Write are invalid", - makePipe: false, - flags: fs.FileFlags{}, /* bogus */ - expectFile: false, - err: syscall.EINVAL, - }, - { - desc: "NonBlocking Read only error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true, NonBlocking: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "NonBlocking Read only success returns immediately", - makePipe: true, - flags: fs.FileFlags{Read: true, NonBlocking: true}, - expectFile: true, - err: nil, - }, - { - desc: "NonBlocking Write only error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Write: true, NonBlocking: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "NonBlocking Write only no reader error returns immediately", - makePipe: true, - flags: fs.FileFlags{Write: true, NonBlocking: true}, - expectFile: false, - err: syscall.ENXIO, - }, - { - desc: "ReadWrite error returns immediately", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true, Write: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "ReadWrite returns immediately", - makePipe: true, - flags: fs.FileFlags{Read: true, Write: true}, - expectFile: true, - err: nil, - }, - { - desc: "Blocking Write only returns open error", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Write: true}, - expectFile: false, - err: syscall.ENOENT, /* from bogus perms */ - }, - { - desc: "Blocking Read only returns open error", - makePipe: false, /* causes the error */ - flags: fs.FileFlags{Read: true}, - expectFile: false, - err: syscall.ENOENT, - }, - { - desc: "Blocking Write only returns with syserror.ErrWouldBlock", - makePipe: true, - flags: fs.FileFlags{Write: true}, - expectFile: false, - err: syserror.ErrWouldBlock, - }, - { - desc: "Blocking Read only returns with syserror.ErrWouldBlock", - makePipe: true, - flags: fs.FileFlags{Read: true}, - expectFile: false, - err: syserror.ErrWouldBlock, - }, - } { - name := pipename() - if test.makePipe { - // Create the pipe. We do this per-test case to keep tests independent. - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer syscall.Unlink(name) - } - - // Use a host opener to keep things simple. - opener := &hostOpener{name: name} - - pipeOpenState := &pipeOpenState{} - ctx := contexttest.Context(t) - pipeOps, err := pipeOpenState.TryOpen(ctx, opener, test.flags) - if unwrapError(err) != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - if pipeOps != nil { - // Cleanup the state of the pipe, and remove the fd from the - // fdnotifier. Sadly this needed to maintain the correctness - // of other tests because the fdnotifier is global. - pipeOps.Release() - } - continue - } - if (pipeOps != nil) != test.expectFile { - t.Errorf("%s: got non-nil file %v, want %v", test.desc, pipeOps != nil, test.expectFile) - } - if pipeOps != nil { - // Same as above. - pipeOps.Release() - } - } -} - -func TestPipeOpenUnblocksEventually(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // partnerIsReader is true if the goroutine opening the same pipe as the test case - // should open the pipe read only. Otherwise write only. This also means that the - // test case will open the pipe in the opposite way. - partnerIsReader bool - - // partnerIsBlocking is true if the goroutine opening the same pipe as the test case - // should do so without the O_NONBLOCK flag, otherwise opens the pipe with O_NONBLOCK - // until ENXIO is not returned. - partnerIsBlocking bool - }{ - { - desc: "Blocking Read with blocking writer partner opens eventually", - partnerIsReader: false, - partnerIsBlocking: true, - }, - { - desc: "Blocking Write with blocking reader partner opens eventually", - partnerIsReader: true, - partnerIsBlocking: true, - }, - { - desc: "Blocking Read with non-blocking writer partner opens eventually", - partnerIsReader: false, - partnerIsBlocking: false, - }, - { - desc: "Blocking Write with non-blocking reader partner opens eventually", - partnerIsReader: true, - partnerIsBlocking: false, - }, - } { - // Create the pipe. We do this per-test case to keep tests independent. - name := pipename() - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer syscall.Unlink(name) - - // Spawn the partner. - type fderr struct { - fd int - err error - } - errch := make(chan fderr, 1) - go func() { - var flags int - if test.partnerIsReader { - flags = syscall.O_RDONLY - } else { - flags = syscall.O_WRONLY - } - if test.partnerIsBlocking { - fd, err := syscall.Open(name, flags, 0666) - errch <- fderr{fd: fd, err: err} - } else { - var fd int - err := error(syscall.ENXIO) - for err == syscall.ENXIO { - fd, err = syscall.Open(name, flags|syscall.O_NONBLOCK, 0666) - time.Sleep(1 * time.Second) - } - errch <- fderr{fd: fd, err: err} - } - }() - - // Setup file flags for either a read only or write only open. - flags := fs.FileFlags{ - Read: !test.partnerIsReader, - Write: test.partnerIsReader, - } - - // Open the pipe in a blocking way, which should succeed eventually. - opener := &hostOpener{name: name} - ctx := contexttest.Context(t) - pipeOps, err := Open(ctx, opener, flags) - if pipeOps != nil { - // Same as TestTryOpen. - pipeOps.Release() - } - - // Check that the partner opened the file successfully. - e := <-errch - if e.err != nil { - t.Errorf("%s: partner got error %v, wanted nil", test.desc, e.err) - continue - } - // If so, then close the partner fd to avoid leaking an fd. - syscall.Close(e.fd) - - // Check that our blocking open was successful. - if err != nil { - t.Errorf("%s: blocking open got error %v, wanted nil", test.desc, err) - continue - } - if pipeOps == nil { - t.Errorf("%s: blocking open got nil file, wanted non-nil", test.desc) - continue - } - } -} - -func TestCopiedReadAheadBuffer(t *testing.T) { - // Create the pipe. - name := pipename() - if err := mkpipe(name); err != nil { - t.Fatalf("failed to make host pipe: %v", err) - } - defer syscall.Unlink(name) - - // We're taking advantage of the fact that pipes opened read only always return - // success, but internally they are not deemed "opened" until we're sure that - // another writer comes along. This means we can open the same pipe write only - // with no problems + write to it, given that opener.Open already tried to open - // the pipe RDONLY and succeeded, which we know happened if TryOpen returns - // syserror.ErrwouldBlock. - // - // This simulates the open(RDONLY) <-> open(WRONLY)+write race we care about, but - // does not cause our test to be racy (which would be terrible). - opener := &hostOpener{name: name} - pipeOpenState := &pipeOpenState{} - ctx := contexttest.Context(t) - pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) - if pipeOps != nil { - pipeOps.Release() - t.Fatalf("open(%s, %o) got file, want nil", name, syscall.O_RDONLY) - } - if err != syserror.ErrWouldBlock { - t.Fatalf("open(%s, %o) got error %v, want %v", name, syscall.O_RDONLY, err, syserror.ErrWouldBlock) - } - - // Then open the same pipe write only and write some bytes to it. The next - // time we try to open the pipe read only again via the pipeOpenState, we should - // succeed and buffer some of the bytes written. - fd, err := syscall.Open(name, syscall.O_WRONLY, 0666) - if err != nil { - t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_WRONLY, err) - } - defer syscall.Close(fd) - - data := []byte("hello") - if n, err := syscall.Write(fd, data); n != len(data) || err != nil { - t.Fatalf("write(%v) got (%d, %v), want (%d, nil)", data, n, err, len(data)) - } - - // Try the read again, knowing that it should succeed this time. - pipeOps, err = pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true}) - if pipeOps == nil { - t.Fatalf("open(%s, %o) got nil file, want not nil", name, syscall.O_RDONLY) - } - defer pipeOps.Release() - - if err != nil { - t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_RDONLY, err) - } - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(inode, "pipe"), fs.FileFlags{Read: true}, pipeOps) - - // Check that the file we opened points to a pipe with a non-empty read ahead buffer. - bufsize := len(pipeOps.readAheadBuffer) - if bufsize != 1 { - t.Fatalf("read ahead buffer got %d bytes, want %d", bufsize, 1) - } - - // Now for the final test, try to read everything in, expecting to get back all of - // the bytes that were written at once. Note that in the wild there is no atomic - // read size so expecting to get all bytes from a single writer when there are - // multiple readers is a bad expectation. - buf := make([]byte, len(data)) - ioseq := usermem.BytesIOSequence(buf) - n, err := pipeOps.Read(ctx, file, ioseq, 0) - if err != nil { - t.Fatalf("read request got error %v, want nil", err) - } - if n != int64(len(data)) { - t.Fatalf("read request got %d bytes, want %d", n, len(data)) - } - if !bytes.Equal(buf, data) { - t.Errorf("read request got bytes [%v], want [%v]", buf, data) - } -} - -func TestPipeHangup(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // flags control how we open our end of the pipe and must be read - // only or write only. They also dicate how a coordinating partner - // fd is opened, which is their inverse (read only -> write only, etc). - flags fs.FileFlags - - // hangupSelf if true causes the test case to close our end of the pipe - // and causes hangup errors to be asserted on our coordinating partner's - // fd. If hangupSelf is false, then our partner's fd is closed and the - // hangup errors are expected on our end of the pipe. - hangupSelf bool - }{ - { - desc: "Read only gets hangup error", - flags: fs.FileFlags{Read: true}, - }, - { - desc: "Write only gets hangup error", - flags: fs.FileFlags{Write: true}, - }, - { - desc: "Read only generates hangup error", - flags: fs.FileFlags{Read: true}, - hangupSelf: true, - }, - { - desc: "Write only generates hangup error", - flags: fs.FileFlags{Write: true}, - hangupSelf: true, - }, - } { - if test.flags.Read == test.flags.Write { - t.Errorf("%s: test requires a single reader or writer", test.desc) - continue - } - - // Create the pipe. We do this per-test case to keep tests independent. - name := pipename() - if err := mkpipe(name); err != nil { - t.Errorf("%s: failed to make host pipe: %v", test.desc, err) - continue - } - defer syscall.Unlink(name) - - // Fire off a partner routine which tries to open the same pipe blocking, - // which will synchronize with us. The channel allows us to get back the - // fd once we expect this partner routine to succeed, so we can manifest - // hangup events more directly. - fdchan := make(chan int, 1) - go func() { - // Be explicit about the flags to protect the test from - // misconfiguration. - var flags int - if test.flags.Read { - flags = syscall.O_WRONLY - } else { - flags = syscall.O_RDONLY - } - fd, err := syscall.Open(name, flags, 0666) - if err != nil { - t.Logf("Open(%q, %o, 0666) partner failed: %v", name, flags, err) - } - fdchan <- fd - }() - - // Open our end in a blocking way to ensure that we coordinate. - opener := &hostOpener{name: name} - ctx := contexttest.Context(t) - pipeOps, err := Open(ctx, opener, test.flags) - if err != nil { - t.Errorf("%s: Open got error %v, want nil", test.desc, err) - continue - } - // Don't defer file.DecRef here because that causes the hangup we're - // trying to test for. - - // Expect the partner routine to have coordinated with us and get back - // its open fd. - f := <-fdchan - if f < 0 { - t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f) - pipeOps.Release() - continue - } - - if test.hangupSelf { - // Hangup self and assert that our partner got the expected hangup - // error. - pipeOps.Release() - - if test.flags.Read { - // Partner is writer. - assertWriterHungup(t, test.desc, fd.NewReadWriter(f)) - } else { - // Partner is reader. - assertReaderHungup(t, test.desc, fd.NewReadWriter(f)) - } - } else { - // Hangup our partner and expect us to get the hangup error. - syscall.Close(f) - defer pipeOps.Release() - - if test.flags.Read { - assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file) - } else { - assertWriterHungup(t, test.desc, pipeOps.(*pipeOperations).file) - } - } - } -} - -func assertReaderHungup(t *testing.T, desc string, reader io.Reader) bool { - // Drain the pipe completely, it might have crap in it, but expect EOF eventually. - var err error - for err == nil { - _, err = reader.Read(make([]byte, 10)) - } - if err != io.EOF { - t.Errorf("%s: read from self after hangup got error %v, want %v", desc, err, io.EOF) - return false - } - return true -} - -func assertWriterHungup(t *testing.T, desc string, writer io.Writer) bool { - if _, err := writer.Write([]byte("hello")); unwrapError(err) != syscall.EPIPE { - t.Errorf("%s: write to self after hangup got error %v, want %v", desc, err, syscall.EPIPE) - return false - } - return true -} diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go deleted file mode 100644 index 21ea14359..000000000 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ /dev/null @@ -1,489 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fdpipe - -import ( - "bytes" - "io" - "os" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -func singlePipeFD() (int, error) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - return -1, err - } - syscall.Close(fds[1]) - return fds[0], nil -} - -func singleDirFD() (int, error) { - return syscall.Open(os.TempDir(), syscall.O_RDONLY, 0666) -} - -func mockPipeDirent(t *testing.T) *fs.Dirent { - ctx := contexttest.Context(t) - node := fs.NewMockInodeOperations(ctx) - node.UAttr = fs.UnstableAttr{ - Perms: fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, - }, - } - inode := fs.NewInode(node, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - BlockSize: usermem.PageSize, - }) - return fs.NewDirent(inode, "") -} - -func TestNewPipe(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // getfd generates the fd to pass to newPipeOperations. - getfd func() (int, error) - - // flags are the fs.FileFlags passed to newPipeOperations. - flags fs.FileFlags - - // readAheadBuffer is the buffer passed to newPipeOperations. - readAheadBuffer []byte - - // err is the expected error. - err error - }{ - { - desc: "Cannot make new pipe from bad fd", - getfd: func() (int, error) { return -1, nil }, - err: syscall.EINVAL, - }, - { - desc: "Cannot make new pipe from non-pipe fd", - getfd: singleDirFD, - err: syscall.EINVAL, - }, - { - desc: "Can make new pipe from pipe fd", - getfd: singlePipeFD, - flags: fs.FileFlags{Read: true}, - readAheadBuffer: []byte("hello"), - }, - } { - gfd, err := test.getfd() - if err != nil { - t.Errorf("%s: getfd got (%d, %v), want (fd, nil)", test.desc, gfd, err) - continue - } - f := fd.New(gfd) - - p, err := newPipeOperations(contexttest.Context(t), nil, test.flags, f, test.readAheadBuffer) - if p != nil { - // This is necessary to remove the fd from the global fd notifier. - defer p.Release() - } else { - // If there is no p to DecRef on, because newPipeOperations failed, then the - // file still needs to be closed. - defer f.Close() - } - - if err != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - continue - } - // Check the state of the pipe given that it was successfully opened. - if err == nil { - if p == nil { - t.Errorf("%s: got nil pipe and nil error, want (pipe, nil)", test.desc) - continue - } - if flags := p.flags; test.flags != flags { - t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags) - continue - } - if len(test.readAheadBuffer) != len(p.readAheadBuffer) { - t.Errorf("%s: got read ahead buffer length %d, want %d", test.desc, len(p.readAheadBuffer), len(test.readAheadBuffer)) - continue - } - fileFlags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(p.file.FD()), syscall.F_GETFL, 0) - if errno != 0 { - t.Errorf("%s: failed to get file flags for fd %d, got %v, want 0", test.desc, p.file.FD(), errno) - continue - } - if fileFlags&syscall.O_NONBLOCK == 0 { - t.Errorf("%s: pipe is blocking, expected non-blocking", test.desc) - continue - } - if !fdnotifier.HasFD(int32(f.FD())) { - t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD) - } - } - } -} - -func TestPipeDestruction(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - f := fd.New(fds[0]) - - // We don't care about the other end, just use the read end. - syscall.Close(fds[1]) - - // Test the read end, but it doesn't really matter which. - p, err := newPipeOperations(contexttest.Context(t), nil, fs.FileFlags{Read: true}, f, nil) - if err != nil { - f.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Drop our only reference, which should trigger the destructor. - p.Release() - - if fdnotifier.HasFD(int32(fds[0])) { - t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0]) - } - if p.file != nil { - t.Errorf("after DecRef got file, want nil") - } -} - -type Seek struct{} - -type ReadDir struct{} - -type Writev struct { - Src usermem.IOSequence -} - -type Readv struct { - Dst usermem.IOSequence -} - -type Fsync struct{} - -func TestPipeRequest(t *testing.T) { - for _, test := range []struct { - // desc is the test's description. - desc string - - // request to execute. - context interface{} - - // flags determines whether to use the read or write end - // of the pipe, for this test it can only be Read or Write. - flags fs.FileFlags - - // keepOpenPartner if false closes the other end of the pipe, - // otherwise this is delayed until the end of the test. - keepOpenPartner bool - - // expected error - err error - }{ - { - desc: "ReadDir on pipe returns ENOTDIR", - context: &ReadDir{}, - err: syscall.ENOTDIR, - }, - { - desc: "Fsync on pipe returns EINVAL", - context: &Fsync{}, - err: syscall.EINVAL, - }, - { - desc: "Seek on pipe returns ESPIPE", - context: &Seek{}, - err: syscall.ESPIPE, - }, - { - desc: "Readv on pipe from empty buffer returns nil", - context: &Readv{Dst: usermem.BytesIOSequence(nil)}, - flags: fs.FileFlags{Read: true}, - }, - { - desc: "Readv on pipe from non-empty buffer and closed partner returns EOF", - context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))}, - flags: fs.FileFlags{Read: true}, - err: io.EOF, - }, - { - desc: "Readv on pipe from non-empty buffer and open partner returns EWOULDBLOCK", - context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))}, - flags: fs.FileFlags{Read: true}, - keepOpenPartner: true, - err: syserror.ErrWouldBlock, - }, - { - desc: "Writev on pipe from empty buffer returns nil", - context: &Writev{Src: usermem.BytesIOSequence(nil)}, - flags: fs.FileFlags{Write: true}, - }, - { - desc: "Writev on pipe from non-empty buffer and closed partner returns EPIPE", - context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))}, - flags: fs.FileFlags{Write: true}, - err: syscall.EPIPE, - }, - { - desc: "Writev on pipe from non-empty buffer and open partner succeeds", - context: &Writev{Src: usermem.BytesIOSequence([]byte("hello"))}, - flags: fs.FileFlags{Write: true}, - keepOpenPartner: true, - }, - } { - if test.flags.Read && test.flags.Write { - panic("both read and write not supported for this test") - } - - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Errorf("%s: failed to create pipes: got %v, want nil", test.desc, err) - continue - } - - // Configure the fd and partner fd based on the file flags. - testFd, partnerFd := fds[0], fds[1] - if test.flags.Write { - testFd, partnerFd = fds[1], fds[0] - } - - // Configure closing the fds. - if test.keepOpenPartner { - defer syscall.Close(partnerFd) - } else { - syscall.Close(partnerFd) - } - - // Create the pipe. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, test.flags, fd.New(testFd), nil) - if err != nil { - t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err) - } - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - file := fs.NewFile(ctx, fs.NewDirent(inode, "pipe"), fs.FileFlags{Read: true}, p) - - // Issue request via the appropriate function. - switch c := test.context.(type) { - case *Seek: - _, err = p.Seek(ctx, file, 0, 0) - case *ReadDir: - _, err = p.Readdir(ctx, file, nil) - case *Readv: - _, err = p.Read(ctx, file, c.Dst, 0) - case *Writev: - _, err = p.Write(ctx, file, c.Src, 0) - case *Fsync: - err = p.Fsync(ctx, file, 0, fs.FileMaxOffset, fs.SyncAll) - default: - t.Errorf("%s: unknown request type %T", test.desc, test.context) - } - - if unwrapError(err) != test.err { - t.Errorf("%s: got error %v, want %v", test.desc, err, test.err) - } - } -} - -func TestPipeReadAheadBuffer(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - rfile := fd.New(fds[0]) - - // Eventually close the write end, which is not wrapped in a pipe object. - defer syscall.Close(fds[1]) - - // Write some bytes to this end. - data := []byte("world") - if n, err := syscall.Write(fds[1], data); n != len(data) || err != nil { - rfile.Close() - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data)) - } - // Close the write end immediately, we don't care about it. - - buffered := []byte("hello ") - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, buffered) - if err != nil { - rfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(inode, "pipe"), fs.FileFlags{Read: true}, p) - - // In total we expect to read data + buffered. - total := append(buffered, data...) - - buf := make([]byte, len(total)) - iov := usermem.BytesIOSequence(buf) - n, err := p.Read(contexttest.Context(t), file, iov, 0) - if err != nil { - t.Fatalf("read request got error %v, want nil", err) - } - if n != int64(len(total)) { - t.Fatalf("read request got %d bytes, want %d", n, len(total)) - } - if !bytes.Equal(buf, total) { - t.Errorf("read request got bytes [%v], want [%v]", buf, total) - } -} - -// This is very important for pipes in general because they can return EWOULDBLOCK and for -// those that block they must continue until they have read all of the data (and report it -// as such. -func TestPipeReadsAccumulate(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - rfile := fd.New(fds[0]) - - // Eventually close the write end, it doesn't depend on a pipe object. - defer syscall.Close(fds[1]) - - // Get a new read only pipe reference. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, rfile, nil) - if err != nil { - rfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Don't forget to remove the fd from the fd notifier. Otherwise other tests will - // likely be borked, because it's global :( - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(inode, "pipe"), fs.FileFlags{Read: true}, p) - - // Write some some bytes to the pipe. - data := []byte("some message") - if n, err := syscall.Write(fds[1], data); n != len(data) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(data)) - } - - // Construct a segment vec that is a bit more than we have written so we trigger - // an EWOULDBLOCK. - wantBytes := len(data) + 1 - readBuffer := make([]byte, wantBytes) - iov := usermem.BytesIOSequence(readBuffer) - n, err := p.Read(ctx, file, iov, 0) - total := n - iov = iov.DropFirst64(n) - if err != syserror.ErrWouldBlock { - t.Fatalf("Readv got error %v, want %v", err, syserror.ErrWouldBlock) - } - - // Write a few more bytes to allow us to read more/accumulate. - extra := []byte("extra") - if n, err := syscall.Write(fds[1], extra); n != len(extra) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(extra)) - } - - // This time, using the same request, we should not block. - n, err = p.Read(ctx, file, iov, 0) - total += n - if err != nil { - t.Fatalf("Readv got error %v, want nil", err) - } - - // Assert that the result we got back is cumulative. - if total != int64(wantBytes) { - t.Fatalf("Readv sequence got %d bytes, want %d", total, wantBytes) - } - - if want := append(data, extra[0]); !bytes.Equal(readBuffer, want) { - t.Errorf("Readv sequence got %v, want %v", readBuffer, want) - } -} - -// Same as TestReadsAccumulate. -func TestPipeWritesAccumulate(t *testing.T) { - fds := make([]int, 2) - if err := syscall.Pipe(fds); err != nil { - t.Fatalf("failed to create pipes: got %v, want nil", err) - } - wfile := fd.New(fds[1]) - - // Eventually close the read end, it doesn't depend on a pipe object. - defer syscall.Close(fds[0]) - - // Get a new write only pipe reference. - ctx := contexttest.Context(t) - p, err := newPipeOperations(ctx, nil, fs.FileFlags{Write: true}, wfile, nil) - if err != nil { - wfile.Close() - t.Fatalf("newPipeOperations got error %v, want nil", err) - } - // Don't forget to remove the fd from the fd notifier. Otherwise other tests will - // likely be borked, because it's global :( - defer p.Release() - - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{ - Type: fs.Pipe, - }) - file := fs.NewFile(ctx, fs.NewDirent(inode, "pipe"), fs.FileFlags{Read: true}, p) - - // Construct a segment vec that is larger than the pipe size to trigger an EWOULDBLOCK. - wantBytes := 65536 * 2 - writeBuffer := make([]byte, wantBytes) - for i := 0; i < wantBytes; i++ { - writeBuffer[i] = 'a' - } - iov := usermem.BytesIOSequence(writeBuffer) - n, err := p.Write(ctx, file, iov, 0) - total := n - iov = iov.DropFirst64(n) - if err != syserror.ErrWouldBlock { - t.Fatalf("Writev got error %v, want %v", err, syserror.ErrWouldBlock) - } - - // Read the entire pipe buf size to make space for the second half. - throwAway := make([]byte, 65536) - if n, err := syscall.Read(fds[0], throwAway); n != len(throwAway) || err != nil { - t.Fatalf("write to pipe got (%d, %v), want (%d, nil)", n, err, len(throwAway)) - } - - // This time we should not block. - n, err = p.Write(ctx, file, iov, 0) - total += n - if err != nil { - t.Fatalf("Writev got error %v, want nil", err) - } - - // Assert that the result we got back is cumulative. - if total != int64(wantBytes) { - t.Fatalf("Writev sequence got %d bytes, want %d", total, wantBytes) - } -} diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go deleted file mode 100644 index 1df826f6c..000000000 --- a/pkg/sentry/fs/file_overlay_test.go +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" -) - -func TestReaddir(t *testing.T) { - ctx := contexttest.Context(t) - ctx = &rootContext{ - Context: ctx, - root: fs.NewDirent(newTestRamfsDir(ctx, nil, nil), "root"), - } - for _, test := range []struct { - // Test description. - desc string - - // Lookup parameters. - dir *fs.Inode - - // Want from lookup. - err error - names []string - }{ - { - desc: "no upper, lower has entries", - dir: fs.NewTestOverlayDir(ctx, - nil, /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - {name: "b"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper has entries, no lower", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - {name: "b"}, - }, nil), /* upper */ - nil, /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper and lower, entries combine", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "b"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "b"}, - }, - { - desc: "upper and lower, entries combine, none are masked", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "c"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "c"}, - }, - { - desc: "upper and lower, entries combine, upper masks some of lower", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - {name: "a"}, - }, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - {name: "b"}, /* will be masked */ - {name: "c"}, - }, nil), /* lower */ - false /* revalidate */), - names: []string{".", "..", "a", "c"}, - }, - } { - t.Run(test.desc, func(t *testing.T) { - openDir, err := test.dir.GetFile(ctx, fs.NewDirent(test.dir, "stub"), fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile got error %v, want nil", err) - } - stubSerializer := &fs.CollectEntriesSerializer{} - err = openDir.Readdir(ctx, stubSerializer) - if err != test.err { - t.Fatalf("Readdir got error %v, want nil", err) - } - if err != nil { - return - } - if !reflect.DeepEqual(stubSerializer.Order, test.names) { - t.Errorf("Readdir got names %v, want %v", stubSerializer.Order, test.names) - } - }) - } -} - -func TestReaddirRevalidation(t *testing.T) { - ctx := contexttest.Context(t) - ctx = &rootContext{ - Context: ctx, - root: fs.NewDirent(newTestRamfsDir(ctx, nil, nil), "root"), - } - - // Create an overlay with two directories, each with one file. - upper := newTestRamfsDir(ctx, []dirContent{{name: "a"}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: "b"}}, nil) - overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */) - - // Get a handle to the dirent in the upper filesystem so that we can - // modify it without going through the dirent. - upperDir := upper.InodeOperations.(*dir).InodeOperations.(*ramfs.Dir) - - // Check that overlay returns the files from both upper and lower. - openDir, err := overlay.GetFile(ctx, fs.NewDirent(overlay, "stub"), fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile got error %v, want nil", err) - } - ser := &fs.CollectEntriesSerializer{} - if err := openDir.Readdir(ctx, ser); err != nil { - t.Fatalf("Readdir got error %v, want nil", err) - } - got, want := ser.Order, []string{".", "..", "a", "b"} - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } - - // Remove "a" from the upper and add "c". - if err := upperDir.Remove(ctx, upper, "a"); err != nil { - t.Fatalf("error removing child: %v", err) - } - upperDir.AddChild(ctx, "c", fs.NewInode(fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermissions{}, 0), - upper.MountSource, fs.StableAttr{Type: fs.RegularFile})) - - // Seek to beginning of the directory and do the readdir again. - if _, err := openDir.Seek(ctx, fs.SeekSet, 0); err != nil { - t.Fatalf("error seeking to beginning of dir: %v", err) - } - ser = &fs.CollectEntriesSerializer{} - if err := openDir.Readdir(ctx, ser); err != nil { - t.Fatalf("Readdir got error %v, want nil", err) - } - - // Readdir should return the updated children. - got, want = ser.Order, []string{".", "..", "b", "c"} - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } -} - -// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with -// a frozen dirent tree does not make Readdir calls to the underlying files. -func TestReaddirOverlayFrozen(t *testing.T) { - ctx := contexttest.Context(t) - - // Create an overlay with two directories, each with two files. - upper := newTestRamfsDir(ctx, []dirContent{{name: "upper-file1"}, {name: "upper-file2"}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: "lower-file1"}, {name: "lower-file2"}}, nil) - overlayInode := fs.NewTestOverlayDir(ctx, upper, lower, false) - - // Set that overlay as the root. - root := fs.NewDirent(overlayInode, "root") - ctx = &rootContext{ - Context: ctx, - root: root, - } - - // Check that calling Readdir on the root now returns all 4 files (2 - // from each layer in the overlay). - rootFile, err := root.Inode.GetFile(ctx, root, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("root.Inode.GetFile failed: %v", err) - } - defer rootFile.DecRef() - ser := &fs.CollectEntriesSerializer{} - if err := rootFile.Readdir(ctx, ser); err != nil { - t.Fatalf("rootFile.Readdir failed: %v", err) - } - if got, want := ser.Order, []string{".", "..", "lower-file1", "lower-file2", "upper-file1", "upper-file2"}; !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } - - // Readdir should have been called on upper and lower. - upperDir := upper.InodeOperations.(*dir) - lowerDir := lower.InodeOperations.(*dir) - if !upperDir.ReaddirCalled { - t.Errorf("upperDir.ReaddirCalled got %v, want true", upperDir.ReaddirCalled) - } - if !lowerDir.ReaddirCalled { - t.Errorf("lowerDir.ReaddirCalled got %v, want true", lowerDir.ReaddirCalled) - } - - // Reset. - upperDir.ReaddirCalled = false - lowerDir.ReaddirCalled = false - - // Take references on "upper-file1" and "lower-file1", pinning them in - // the dirent tree. - for _, name := range []string{"upper-file1", "lower-file1"} { - if _, err := root.Walk(ctx, root, name); err != nil { - t.Fatalf("root.Walk(%q) failed: %v", name, err) - } - // Don't drop a reference on the returned dirent so that it - // will stay in the tree. - } - - // Freeze the dirent tree. - root.Freeze() - - // Seek back to the beginning of the file. - if _, err := rootFile.Seek(ctx, fs.SeekSet, 0); err != nil { - t.Fatalf("error seeking to beginning of directory: %v", err) - } - - // Calling Readdir on the root now will return only the pinned - // children. - ser = &fs.CollectEntriesSerializer{} - if err := rootFile.Readdir(ctx, ser); err != nil { - t.Fatalf("rootFile.Readdir failed: %v", err) - } - if got, want := ser.Order, []string{".", "..", "lower-file1", "upper-file1"}; !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got names %v, want %v", got, want) - } - - // Readdir should NOT have been called on upper or lower. - if upperDir.ReaddirCalled { - t.Errorf("upperDir.ReaddirCalled got %v, want false", upperDir.ReaddirCalled) - } - if lowerDir.ReaddirCalled { - t.Errorf("lowerDir.ReaddirCalled got %v, want false", lowerDir.ReaddirCalled) - } -} - -type rootContext struct { - context.Context - root *fs.Dirent -} - -// Value implements context.Context. -func (r *rootContext) Value(key interface{}) interface{} { - switch key { - case fs.CtxRoot: - r.root.IncRef() - return r.root - default: - return r.Context.Value(key) - } -} diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD deleted file mode 100644 index a9d6d9301..000000000 --- a/pkg/sentry/fs/filetest/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "filetest", - testonly = 1, - srcs = ["filetest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/filetest", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go deleted file mode 100644 index 856477480..000000000 --- a/pkg/sentry/fs/filetest/filetest.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package filetest provides a test implementation of an fs.File. -package filetest - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -// TestFileOperations is an implementation of the File interface. It provides all -// required methods. -type TestFileOperations struct { - fsutil.FileNoopRelease `state:"nosave"` - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoIoctl `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.AlwaysReady `state:"nosave"` -} - -// NewTestFile creates and initializes a new test file. -func NewTestFile(tb testing.TB) *fs.File { - ctx := contexttest.Context(tb) - dirent := fs.NewDirent(anon.NewInode(ctx), "test") - return fs.NewFile(ctx, dirent, fs.FileFlags{}, &TestFileOperations{}) -} - -// Read just fails the request. -func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Readv not implemented") -} - -// Write just fails the request. -func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Writev not implemented") -} diff --git a/pkg/sentry/fs/fs_state_autogen.go b/pkg/sentry/fs/fs_state_autogen.go new file mode 100755 index 000000000..f547ccd0f --- /dev/null +++ b/pkg/sentry/fs/fs_state_autogen.go @@ -0,0 +1,626 @@ +// automatically generated by stateify. + +package fs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *StableAttr) beforeSave() {} +func (x *StableAttr) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("DeviceID", &x.DeviceID) + m.Save("InodeID", &x.InodeID) + m.Save("BlockSize", &x.BlockSize) + m.Save("DeviceFileMajor", &x.DeviceFileMajor) + m.Save("DeviceFileMinor", &x.DeviceFileMinor) +} + +func (x *StableAttr) afterLoad() {} +func (x *StableAttr) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("DeviceID", &x.DeviceID) + m.Load("InodeID", &x.InodeID) + m.Load("BlockSize", &x.BlockSize) + m.Load("DeviceFileMajor", &x.DeviceFileMajor) + m.Load("DeviceFileMinor", &x.DeviceFileMinor) +} + +func (x *UnstableAttr) beforeSave() {} +func (x *UnstableAttr) save(m state.Map) { + x.beforeSave() + m.Save("Size", &x.Size) + m.Save("Usage", &x.Usage) + m.Save("Perms", &x.Perms) + m.Save("Owner", &x.Owner) + m.Save("AccessTime", &x.AccessTime) + m.Save("ModificationTime", &x.ModificationTime) + m.Save("StatusChangeTime", &x.StatusChangeTime) + m.Save("Links", &x.Links) +} + +func (x *UnstableAttr) afterLoad() {} +func (x *UnstableAttr) load(m state.Map) { + m.Load("Size", &x.Size) + m.Load("Usage", &x.Usage) + m.Load("Perms", &x.Perms) + m.Load("Owner", &x.Owner) + m.Load("AccessTime", &x.AccessTime) + m.Load("ModificationTime", &x.ModificationTime) + m.Load("StatusChangeTime", &x.StatusChangeTime) + m.Load("Links", &x.Links) +} + +func (x *AttrMask) beforeSave() {} +func (x *AttrMask) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("DeviceID", &x.DeviceID) + m.Save("InodeID", &x.InodeID) + m.Save("BlockSize", &x.BlockSize) + m.Save("Size", &x.Size) + m.Save("Usage", &x.Usage) + m.Save("Perms", &x.Perms) + m.Save("UID", &x.UID) + m.Save("GID", &x.GID) + m.Save("AccessTime", &x.AccessTime) + m.Save("ModificationTime", &x.ModificationTime) + m.Save("StatusChangeTime", &x.StatusChangeTime) + m.Save("Links", &x.Links) +} + +func (x *AttrMask) afterLoad() {} +func (x *AttrMask) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("DeviceID", &x.DeviceID) + m.Load("InodeID", &x.InodeID) + m.Load("BlockSize", &x.BlockSize) + m.Load("Size", &x.Size) + m.Load("Usage", &x.Usage) + m.Load("Perms", &x.Perms) + m.Load("UID", &x.UID) + m.Load("GID", &x.GID) + m.Load("AccessTime", &x.AccessTime) + m.Load("ModificationTime", &x.ModificationTime) + m.Load("StatusChangeTime", &x.StatusChangeTime) + m.Load("Links", &x.Links) +} + +func (x *PermMask) beforeSave() {} +func (x *PermMask) save(m state.Map) { + x.beforeSave() + m.Save("Read", &x.Read) + m.Save("Write", &x.Write) + m.Save("Execute", &x.Execute) +} + +func (x *PermMask) afterLoad() {} +func (x *PermMask) load(m state.Map) { + m.Load("Read", &x.Read) + m.Load("Write", &x.Write) + m.Load("Execute", &x.Execute) +} + +func (x *FilePermissions) beforeSave() {} +func (x *FilePermissions) save(m state.Map) { + x.beforeSave() + m.Save("User", &x.User) + m.Save("Group", &x.Group) + m.Save("Other", &x.Other) + m.Save("Sticky", &x.Sticky) + m.Save("SetUID", &x.SetUID) + m.Save("SetGID", &x.SetGID) +} + +func (x *FilePermissions) afterLoad() {} +func (x *FilePermissions) load(m state.Map) { + m.Load("User", &x.User) + m.Load("Group", &x.Group) + m.Load("Other", &x.Other) + m.Load("Sticky", &x.Sticky) + m.Load("SetUID", &x.SetUID) + m.Load("SetGID", &x.SetGID) +} + +func (x *FileOwner) beforeSave() {} +func (x *FileOwner) save(m state.Map) { + x.beforeSave() + m.Save("UID", &x.UID) + m.Save("GID", &x.GID) +} + +func (x *FileOwner) afterLoad() {} +func (x *FileOwner) load(m state.Map) { + m.Load("UID", &x.UID) + m.Load("GID", &x.GID) +} + +func (x *DentAttr) beforeSave() {} +func (x *DentAttr) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("InodeID", &x.InodeID) +} + +func (x *DentAttr) afterLoad() {} +func (x *DentAttr) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("InodeID", &x.InodeID) +} + +func (x *SortedDentryMap) beforeSave() {} +func (x *SortedDentryMap) save(m state.Map) { + x.beforeSave() + m.Save("names", &x.names) + m.Save("entries", &x.entries) +} + +func (x *SortedDentryMap) afterLoad() {} +func (x *SortedDentryMap) load(m state.Map) { + m.Load("names", &x.names) + m.Load("entries", &x.entries) +} + +func (x *Dirent) save(m state.Map) { + x.beforeSave() + var children map[string]*Dirent = x.saveChildren() + m.SaveValue("children", children) + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("userVisible", &x.userVisible) + m.Save("Inode", &x.Inode) + m.Save("name", &x.name) + m.Save("parent", &x.parent) + m.Save("deleted", &x.deleted) + m.Save("frozen", &x.frozen) + m.Save("mounted", &x.mounted) +} + +func (x *Dirent) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("userVisible", &x.userVisible) + m.Load("Inode", &x.Inode) + m.Load("name", &x.name) + m.Load("parent", &x.parent) + m.Load("deleted", &x.deleted) + m.Load("frozen", &x.frozen) + m.Load("mounted", &x.mounted) + m.LoadValue("children", new(map[string]*Dirent), func(y interface{}) { x.loadChildren(y.(map[string]*Dirent)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *DirentCache) beforeSave() {} +func (x *DirentCache) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.currentSize) { m.Failf("currentSize is %v, expected zero", x.currentSize) } + if !state.IsZeroValue(x.list) { m.Failf("list is %v, expected zero", x.list) } + m.Save("maxSize", &x.maxSize) + m.Save("limit", &x.limit) +} + +func (x *DirentCache) afterLoad() {} +func (x *DirentCache) load(m state.Map) { + m.Load("maxSize", &x.maxSize) + m.Load("limit", &x.limit) +} + +func (x *DirentCacheLimiter) beforeSave() {} +func (x *DirentCacheLimiter) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.count) { m.Failf("count is %v, expected zero", x.count) } + m.Save("max", &x.max) +} + +func (x *DirentCacheLimiter) afterLoad() {} +func (x *DirentCacheLimiter) load(m state.Map) { + m.Load("max", &x.max) +} + +func (x *direntList) beforeSave() {} +func (x *direntList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *direntList) afterLoad() {} +func (x *direntList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *direntEntry) beforeSave() {} +func (x *direntEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *direntEntry) afterLoad() {} +func (x *direntEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *eventList) beforeSave() {} +func (x *eventList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *eventList) afterLoad() {} +func (x *eventList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *eventEntry) beforeSave() {} +func (x *eventEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *eventEntry) afterLoad() {} +func (x *eventEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *File) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("UniqueID", &x.UniqueID) + m.Save("Dirent", &x.Dirent) + m.Save("flags", &x.flags) + m.Save("async", &x.async) + m.Save("FileOperations", &x.FileOperations) + m.Save("offset", &x.offset) +} + +func (x *File) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("UniqueID", &x.UniqueID) + m.Load("Dirent", &x.Dirent) + m.Load("flags", &x.flags) + m.Load("async", &x.async) + m.LoadWait("FileOperations", &x.FileOperations) + m.Load("offset", &x.offset) + m.AfterLoad(x.afterLoad) +} + +func (x *overlayFileOperations) beforeSave() {} +func (x *overlayFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("upper", &x.upper) + m.Save("lower", &x.lower) + m.Save("dirCursor", &x.dirCursor) + m.Save("dirCache", &x.dirCache) +} + +func (x *overlayFileOperations) afterLoad() {} +func (x *overlayFileOperations) load(m state.Map) { + m.Load("upper", &x.upper) + m.Load("lower", &x.lower) + m.Load("dirCursor", &x.dirCursor) + m.Load("dirCache", &x.dirCache) +} + +func (x *overlayMappingIdentity) beforeSave() {} +func (x *overlayMappingIdentity) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("id", &x.id) + m.Save("overlayFile", &x.overlayFile) +} + +func (x *overlayMappingIdentity) afterLoad() {} +func (x *overlayMappingIdentity) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("id", &x.id) + m.Load("overlayFile", &x.overlayFile) +} + +func (x *MountSourceFlags) beforeSave() {} +func (x *MountSourceFlags) save(m state.Map) { + x.beforeSave() + m.Save("ReadOnly", &x.ReadOnly) + m.Save("NoAtime", &x.NoAtime) + m.Save("ForcePageCache", &x.ForcePageCache) + m.Save("NoExec", &x.NoExec) +} + +func (x *MountSourceFlags) afterLoad() {} +func (x *MountSourceFlags) load(m state.Map) { + m.Load("ReadOnly", &x.ReadOnly) + m.Load("NoAtime", &x.NoAtime) + m.Load("ForcePageCache", &x.ForcePageCache) + m.Load("NoExec", &x.NoExec) +} + +func (x *FileFlags) beforeSave() {} +func (x *FileFlags) save(m state.Map) { + x.beforeSave() + m.Save("Direct", &x.Direct) + m.Save("NonBlocking", &x.NonBlocking) + m.Save("Sync", &x.Sync) + m.Save("Append", &x.Append) + m.Save("Read", &x.Read) + m.Save("Write", &x.Write) + m.Save("Pread", &x.Pread) + m.Save("Pwrite", &x.Pwrite) + m.Save("Directory", &x.Directory) + m.Save("Async", &x.Async) + m.Save("LargeFile", &x.LargeFile) +} + +func (x *FileFlags) afterLoad() {} +func (x *FileFlags) load(m state.Map) { + m.Load("Direct", &x.Direct) + m.Load("NonBlocking", &x.NonBlocking) + m.Load("Sync", &x.Sync) + m.Load("Append", &x.Append) + m.Load("Read", &x.Read) + m.Load("Write", &x.Write) + m.Load("Pread", &x.Pread) + m.Load("Pwrite", &x.Pwrite) + m.Load("Directory", &x.Directory) + m.Load("Async", &x.Async) + m.Load("LargeFile", &x.LargeFile) +} + +func (x *Inode) beforeSave() {} +func (x *Inode) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("InodeOperations", &x.InodeOperations) + m.Save("StableAttr", &x.StableAttr) + m.Save("LockCtx", &x.LockCtx) + m.Save("Watches", &x.Watches) + m.Save("MountSource", &x.MountSource) + m.Save("overlay", &x.overlay) +} + +func (x *Inode) afterLoad() {} +func (x *Inode) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("InodeOperations", &x.InodeOperations) + m.Load("StableAttr", &x.StableAttr) + m.Load("LockCtx", &x.LockCtx) + m.Load("Watches", &x.Watches) + m.Load("MountSource", &x.MountSource) + m.Load("overlay", &x.overlay) +} + +func (x *LockCtx) beforeSave() {} +func (x *LockCtx) save(m state.Map) { + x.beforeSave() + m.Save("Posix", &x.Posix) + m.Save("BSD", &x.BSD) +} + +func (x *LockCtx) afterLoad() {} +func (x *LockCtx) load(m state.Map) { + m.Load("Posix", &x.Posix) + m.Load("BSD", &x.BSD) +} + +func (x *Watches) beforeSave() {} +func (x *Watches) save(m state.Map) { + x.beforeSave() + m.Save("ws", &x.ws) + m.Save("unlinked", &x.unlinked) +} + +func (x *Watches) afterLoad() {} +func (x *Watches) load(m state.Map) { + m.Load("ws", &x.ws) + m.Load("unlinked", &x.unlinked) +} + +func (x *Inotify) beforeSave() {} +func (x *Inotify) save(m state.Map) { + x.beforeSave() + m.Save("id", &x.id) + m.Save("events", &x.events) + m.Save("scratch", &x.scratch) + m.Save("nextWatch", &x.nextWatch) + m.Save("watches", &x.watches) +} + +func (x *Inotify) afterLoad() {} +func (x *Inotify) load(m state.Map) { + m.Load("id", &x.id) + m.Load("events", &x.events) + m.Load("scratch", &x.scratch) + m.Load("nextWatch", &x.nextWatch) + m.Load("watches", &x.watches) +} + +func (x *Event) beforeSave() {} +func (x *Event) save(m state.Map) { + x.beforeSave() + m.Save("eventEntry", &x.eventEntry) + m.Save("wd", &x.wd) + m.Save("mask", &x.mask) + m.Save("cookie", &x.cookie) + m.Save("len", &x.len) + m.Save("name", &x.name) +} + +func (x *Event) afterLoad() {} +func (x *Event) load(m state.Map) { + m.Load("eventEntry", &x.eventEntry) + m.Load("wd", &x.wd) + m.Load("mask", &x.mask) + m.Load("cookie", &x.cookie) + m.Load("len", &x.len) + m.Load("name", &x.name) +} + +func (x *Watch) beforeSave() {} +func (x *Watch) save(m state.Map) { + x.beforeSave() + m.Save("owner", &x.owner) + m.Save("wd", &x.wd) + m.Save("target", &x.target) + m.Save("unpinned", &x.unpinned) + m.Save("mask", &x.mask) + m.Save("pins", &x.pins) +} + +func (x *Watch) afterLoad() {} +func (x *Watch) load(m state.Map) { + m.Load("owner", &x.owner) + m.Load("wd", &x.wd) + m.Load("target", &x.target) + m.Load("unpinned", &x.unpinned) + m.Load("mask", &x.mask) + m.Load("pins", &x.pins) +} + +func (x *MountSource) beforeSave() {} +func (x *MountSource) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("MountSourceOperations", &x.MountSourceOperations) + m.Save("FilesystemType", &x.FilesystemType) + m.Save("Flags", &x.Flags) + m.Save("fscache", &x.fscache) + m.Save("direntRefs", &x.direntRefs) +} + +func (x *MountSource) afterLoad() {} +func (x *MountSource) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("MountSourceOperations", &x.MountSourceOperations) + m.Load("FilesystemType", &x.FilesystemType) + m.Load("Flags", &x.Flags) + m.Load("fscache", &x.fscache) + m.Load("direntRefs", &x.direntRefs) +} + +func (x *SimpleMountSourceOperations) beforeSave() {} +func (x *SimpleMountSourceOperations) save(m state.Map) { + x.beforeSave() + m.Save("keep", &x.keep) + m.Save("revalidate", &x.revalidate) +} + +func (x *SimpleMountSourceOperations) afterLoad() {} +func (x *SimpleMountSourceOperations) load(m state.Map) { + m.Load("keep", &x.keep) + m.Load("revalidate", &x.revalidate) +} + +func (x *overlayMountSourceOperations) beforeSave() {} +func (x *overlayMountSourceOperations) save(m state.Map) { + x.beforeSave() + m.Save("upper", &x.upper) + m.Save("lower", &x.lower) +} + +func (x *overlayMountSourceOperations) afterLoad() {} +func (x *overlayMountSourceOperations) load(m state.Map) { + m.Load("upper", &x.upper) + m.Load("lower", &x.lower) +} + +func (x *overlayFilesystem) beforeSave() {} +func (x *overlayFilesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *overlayFilesystem) afterLoad() {} +func (x *overlayFilesystem) load(m state.Map) { +} + +func (x *Mount) beforeSave() {} +func (x *Mount) save(m state.Map) { + x.beforeSave() + m.Save("ID", &x.ID) + m.Save("ParentID", &x.ParentID) + m.Save("root", &x.root) + m.Save("previous", &x.previous) +} + +func (x *Mount) afterLoad() {} +func (x *Mount) load(m state.Map) { + m.Load("ID", &x.ID) + m.Load("ParentID", &x.ParentID) + m.Load("root", &x.root) + m.Load("previous", &x.previous) +} + +func (x *MountNamespace) beforeSave() {} +func (x *MountNamespace) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("userns", &x.userns) + m.Save("root", &x.root) + m.Save("mounts", &x.mounts) + m.Save("mountID", &x.mountID) +} + +func (x *MountNamespace) afterLoad() {} +func (x *MountNamespace) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("userns", &x.userns) + m.Load("root", &x.root) + m.Load("mounts", &x.mounts) + m.Load("mountID", &x.mountID) +} + +func (x *overlayEntry) beforeSave() {} +func (x *overlayEntry) save(m state.Map) { + x.beforeSave() + m.Save("lowerExists", &x.lowerExists) + m.Save("lower", &x.lower) + m.Save("mappings", &x.mappings) + m.Save("upper", &x.upper) +} + +func (x *overlayEntry) afterLoad() {} +func (x *overlayEntry) load(m state.Map) { + m.Load("lowerExists", &x.lowerExists) + m.Load("lower", &x.lower) + m.Load("mappings", &x.mappings) + m.Load("upper", &x.upper) +} + +func init() { + state.Register("fs.StableAttr", (*StableAttr)(nil), state.Fns{Save: (*StableAttr).save, Load: (*StableAttr).load}) + state.Register("fs.UnstableAttr", (*UnstableAttr)(nil), state.Fns{Save: (*UnstableAttr).save, Load: (*UnstableAttr).load}) + state.Register("fs.AttrMask", (*AttrMask)(nil), state.Fns{Save: (*AttrMask).save, Load: (*AttrMask).load}) + state.Register("fs.PermMask", (*PermMask)(nil), state.Fns{Save: (*PermMask).save, Load: (*PermMask).load}) + state.Register("fs.FilePermissions", (*FilePermissions)(nil), state.Fns{Save: (*FilePermissions).save, Load: (*FilePermissions).load}) + state.Register("fs.FileOwner", (*FileOwner)(nil), state.Fns{Save: (*FileOwner).save, Load: (*FileOwner).load}) + state.Register("fs.DentAttr", (*DentAttr)(nil), state.Fns{Save: (*DentAttr).save, Load: (*DentAttr).load}) + state.Register("fs.SortedDentryMap", (*SortedDentryMap)(nil), state.Fns{Save: (*SortedDentryMap).save, Load: (*SortedDentryMap).load}) + state.Register("fs.Dirent", (*Dirent)(nil), state.Fns{Save: (*Dirent).save, Load: (*Dirent).load}) + state.Register("fs.DirentCache", (*DirentCache)(nil), state.Fns{Save: (*DirentCache).save, Load: (*DirentCache).load}) + state.Register("fs.DirentCacheLimiter", (*DirentCacheLimiter)(nil), state.Fns{Save: (*DirentCacheLimiter).save, Load: (*DirentCacheLimiter).load}) + state.Register("fs.direntList", (*direntList)(nil), state.Fns{Save: (*direntList).save, Load: (*direntList).load}) + state.Register("fs.direntEntry", (*direntEntry)(nil), state.Fns{Save: (*direntEntry).save, Load: (*direntEntry).load}) + state.Register("fs.eventList", (*eventList)(nil), state.Fns{Save: (*eventList).save, Load: (*eventList).load}) + state.Register("fs.eventEntry", (*eventEntry)(nil), state.Fns{Save: (*eventEntry).save, Load: (*eventEntry).load}) + state.Register("fs.File", (*File)(nil), state.Fns{Save: (*File).save, Load: (*File).load}) + state.Register("fs.overlayFileOperations", (*overlayFileOperations)(nil), state.Fns{Save: (*overlayFileOperations).save, Load: (*overlayFileOperations).load}) + state.Register("fs.overlayMappingIdentity", (*overlayMappingIdentity)(nil), state.Fns{Save: (*overlayMappingIdentity).save, Load: (*overlayMappingIdentity).load}) + state.Register("fs.MountSourceFlags", (*MountSourceFlags)(nil), state.Fns{Save: (*MountSourceFlags).save, Load: (*MountSourceFlags).load}) + state.Register("fs.FileFlags", (*FileFlags)(nil), state.Fns{Save: (*FileFlags).save, Load: (*FileFlags).load}) + state.Register("fs.Inode", (*Inode)(nil), state.Fns{Save: (*Inode).save, Load: (*Inode).load}) + state.Register("fs.LockCtx", (*LockCtx)(nil), state.Fns{Save: (*LockCtx).save, Load: (*LockCtx).load}) + state.Register("fs.Watches", (*Watches)(nil), state.Fns{Save: (*Watches).save, Load: (*Watches).load}) + state.Register("fs.Inotify", (*Inotify)(nil), state.Fns{Save: (*Inotify).save, Load: (*Inotify).load}) + state.Register("fs.Event", (*Event)(nil), state.Fns{Save: (*Event).save, Load: (*Event).load}) + state.Register("fs.Watch", (*Watch)(nil), state.Fns{Save: (*Watch).save, Load: (*Watch).load}) + state.Register("fs.MountSource", (*MountSource)(nil), state.Fns{Save: (*MountSource).save, Load: (*MountSource).load}) + state.Register("fs.SimpleMountSourceOperations", (*SimpleMountSourceOperations)(nil), state.Fns{Save: (*SimpleMountSourceOperations).save, Load: (*SimpleMountSourceOperations).load}) + state.Register("fs.overlayMountSourceOperations", (*overlayMountSourceOperations)(nil), state.Fns{Save: (*overlayMountSourceOperations).save, Load: (*overlayMountSourceOperations).load}) + state.Register("fs.overlayFilesystem", (*overlayFilesystem)(nil), state.Fns{Save: (*overlayFilesystem).save, Load: (*overlayFilesystem).load}) + state.Register("fs.Mount", (*Mount)(nil), state.Fns{Save: (*Mount).save, Load: (*Mount).load}) + state.Register("fs.MountNamespace", (*MountNamespace)(nil), state.Fns{Save: (*MountNamespace).save, Load: (*MountNamespace).load}) + state.Register("fs.overlayEntry", (*overlayEntry)(nil), state.Fns{Save: (*overlayEntry).save, Load: (*overlayEntry).load}) +} diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD deleted file mode 100644 index 6499f87ac..000000000 --- a/pkg/sentry/fs/fsutil/BUILD +++ /dev/null @@ -1,118 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "dirty_set_impl", - out = "dirty_set_impl.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "fsutil", - prefix = "Dirty", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.MappableRange", - "Value": "DirtyInfo", - "Functions": "dirtySetFunctions", - }, -) - -go_template_instance( - name = "frame_ref_set_impl", - out = "frame_ref_set_impl.go", - imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "fsutil", - prefix = "frameRef", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "platform.FileRange", - "Value": "uint64", - "Functions": "frameRefSetFunctions", - }, -) - -go_template_instance( - name = "file_range_set_impl", - out = "file_range_set_impl.go", - imports = { - "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "fsutil", - prefix = "FileRange", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "memmap.MappableRange", - "Value": "uint64", - "Functions": "fileRangeSetFunctions", - }, -) - -go_library( - name = "fsutil", - srcs = [ - "dirty_set.go", - "dirty_set_impl.go", - "file.go", - "file_range_set.go", - "file_range_set_impl.go", - "frame_ref_set.go", - "frame_ref_set_impl.go", - "fsutil.go", - "host_file_mapper.go", - "host_file_mapper_state.go", - "host_file_mapper_unsafe.go", - "host_mappable.go", - "inode.go", - "inode_cached.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fsutil", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/safemem", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/state", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "fsutil_test", - size = "small", - srcs = [ - "dirty_set_test.go", - "inode_cached_test.go", - ], - embed = [":fsutil"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/fs/fsutil/README.md b/pkg/sentry/fs/fsutil/README.md deleted file mode 100644 index 8be367334..000000000 --- a/pkg/sentry/fs/fsutil/README.md +++ /dev/null @@ -1,207 +0,0 @@ -This package provides utilities for implementing virtual filesystem objects. - -[TOC] - -## Page cache - -`CachingInodeOperations` implements a page cache for files that cannot use the -host page cache. Normally these are files that store their data in a remote -filesystem. This also applies to files that are accessed on a platform that does -not support directly memory mapping host file descriptors (e.g. the ptrace -platform). - -An `CachingInodeOperations` buffers regions of a single file into memory. It is -owned by an `fs.Inode`, the in-memory representation of a file (all open file -descriptors are backed by an `fs.Inode`). The `fs.Inode` provides operations for -reading memory into an `CachingInodeOperations`, to represent the contents of -the file in-memory, and for writing memory out, to relieve memory pressure on -the kernel and to synchronize in-memory changes to filesystems. - -An `CachingInodeOperations` enables readable and/or writable memory access to -file content. Files can be mapped shared or private, see mmap(2). When a file is -mapped shared, changes to the file via write(2) and truncate(2) are reflected in -the shared memory region. Conversely, when the shared memory region is modified, -changes to the file are visible via read(2). Multiple shared mappings of the -same file are coherent with each other. This is consistent with Linux. - -When a file is mapped private, updates to the mapped memory are not visible to -other memory mappings. Updates to the mapped memory are also not reflected in -the file content as seen by read(2). If the file is changed after a private -mapping is created, for instance by write(2), the change to the file may or may -not be reflected in the private mapping. This is consistent with Linux. - -An `CachingInodeOperations` keeps track of ranges of memory that were modified -(or "dirtied"). When the file is explicitly synced via fsync(2), only the dirty -ranges are written out to the filesystem. Any error returned indicates a failure -to write all dirty memory of an `CachingInodeOperations` to the filesystem. In -this case the filesystem may be in an inconsistent state. The same operation can -be performed on the shared memory itself using msync(2). If neither fsync(2) nor -msync(2) is performed, then the dirty memory is written out in accordance with -the `CachingInodeOperations` eviction strategy (see below) and there is no -guarantee that memory will be written out successfully in full. - -### Memory allocation and eviction - -An `CachingInodeOperations` implements the following allocation and eviction -strategy: - -- Memory is allocated and brought up to date with the contents of a file when - a region of mapped memory is accessed (or "faulted on"). - -- Dirty memory is written out to filesystems when an fsync(2) or msync(2) - operation is performed on a memory mapped file, for all memory mapped files - when saved, and/or when there are no longer any memory mappings of a range - of a file, see munmap(2). As the latter implies, in the absence of a panic - or SIGKILL, dirty memory is written out for all memory mapped files when an - application exits. - -- Memory is freed when there are no longer any memory mappings of a range of a - file (e.g. when an application exits). This behavior is consistent with - Linux for shared memory that has been locked via mlock(2). - -Notably, memory is not allocated for read(2) or write(2) operations. This means -that reads and writes to the file are only accelerated by an -`CachingInodeOperations` if the file being read or written has been memory -mapped *and* if the shared memory has been accessed at the region being read or -written. This diverges from Linux which buffers memory into a page cache on -read(2) proactively (i.e. readahead) and delays writing it out to filesystems on -write(2) (i.e. writeback). The absence of these optimizations is not visible to -applications beyond less than optimal performance when repeatedly reading and/or -writing to same region of a file. See [Future Work](#future-work) for plans to -implement these optimizations. - -Additionally, memory held by `CachingInodeOperationss` is currently unbounded in -size. An `CachingInodeOperations` does not write out dirty memory and free it -under system memory pressure. This can cause pathological memory usage. - -When memory is written back, an `CachingInodeOperations` may write regions of -shared memory that were never modified. This is due to the strategy of -minimizing page faults (see below) and handling only a subset of memory write -faults. In the absence of an application or sentry crash, it is guaranteed that -if a region of shared memory was written to, it is written back to a filesystem. - -### Life of a shared memory mapping - -A file is memory mapped via mmap(2). For example, if `A` is an address, an -application may execute: - -``` -mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); -``` - -This creates a shared mapping of fd that reflects 4k of the contents of fd -starting at offset 0, accessible at address `A`. This in turn creates a virtual -memory area region ("vma") which indicates that [`A`, `A`+0x1000) is now a valid -address range for this application to access. - -At this point, memory has not been allocated in the file's -`CachingInodeOperations`. It is also the case that the address range [`A`, -`A`+0x1000) has not been mapped on the host on behalf of the application. If the -application then tries to modify 8 bytes of the shared memory: - -``` -char buffer[] = "aaaaaaaa"; -memcpy(A, buffer, 8); -``` - -The host then sends a `SIGSEGV` to the sentry because the address range [`A`, -`A`+8) is not mapped on the host. The `SIGSEGV` indicates that the memory was -accessed writable. The sentry looks up the vma associated with [`A`, `A`+8), -finds the file that was mapped and its `CachingInodeOperations`. It then calls -`CachingInodeOperations.Translate` which allocates memory to back [`A`, `A`+8). -It may choose to allocate more memory (i.e. do "readahead") to minimize -subsequent faults. - -Memory that is allocated comes from a host tmpfs file (see -`pgalloc.MemoryFile`). The host tmpfs file memory is brought up to date with the -contents of the mapped file on its filesystem. The region of the host tmpfs file -that reflects the mapped file is then mapped into the host address space of the -application so that subsequent memory accesses do not repeatedly generate a -`SIGSEGV`. - -The range that was allocated, including any extra memory allocation to minimize -faults, is marked dirty due to the write fault. This overcounts dirty memory if -the extra memory allocated is never modified. - -To make the scenario more interesting, imagine that this application spawns -another process and maps the same file in the exact same way: - -``` -mmap(A, 0x1000, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); -``` - -Imagine that this process then tries to modify the file again but with only 4 -bytes: - -``` -char buffer[] = "bbbb"; -memcpy(A, buffer, 4); -``` - -Since the first process has already mapped and accessed the same region of the -file writable, `CachingInodeOperations.Translate` is called but returns the -memory that has already been allocated rather than allocating new memory. The -address range [`A`, `A`+0x1000) reflects the same cached view of the file as the -first process sees. For example, reading 8 bytes from the file from either -process via read(2) starting at offset 0 returns a consistent "bbbbaaaa". - -When this process no longer needs the shared memory, it may do: - -``` -munmap(A, 0x1000); -``` - -At this point, the modified memory cached by the `CachingInodeOperations` is not -written back to the file because it is still in use by the first process that -mapped it. When the first process also does: - -``` -munmap(A, 0x1000); -``` - -Then the last memory mapping of the file at the range [0, 0x1000) is gone. The -file's `CachingInodeOperations` then starts writing back memory marked dirty to -the file on its filesystem. Once writing completes, regardless of whether it was -successful, the `CachingInodeOperations` frees the memory cached at the range -[0, 0x1000). - -Subsequent read(2) or write(2) operations on the file go directly to the -filesystem since there no longer exists memory for it in its -`CachingInodeOperations`. - -## Future Work - -### Page cache - -The sentry does not yet implement the readahead and writeback optimizations for -read(2) and write(2) respectively. To do so, on read(2) and/or write(2) the -sentry must ensure that memory is allocated in a page cache to read or write -into. However, the sentry cannot boundlessly allocate memory. If it did, the -host would eventually OOM-kill the sentry+application process. This means that -the sentry must implement a page cache memory allocation strategy that is -bounded by a global user or container imposed limit. When this limit is -approached, the sentry must decide from which page cache memory should be freed -so that it can allocate more memory. If it makes a poor decision, the sentry may -end up freeing and re-allocating memory to back regions of files that are -frequently used, nullifying the optimization (and in some cases causing worse -performance due to the overhead of memory allocation and general management). -This is a form of "cache thrashing". - -In Linux, much research has been done to select and implement a lightweight but -optimal page cache eviction algorithm. Linux makes use of hardware page bits to -keep track of whether memory has been accessed. The sentry does not have direct -access to hardware. Implementing a similarly lightweight and optimal page cache -eviction algorithm will need to either introduce a kernel interface to obtain -these page bits or find a suitable alternative proxy for access events. - -In Linux, readahead happens by default but is not always ideal. For instance, -for files that are not read sequentially, it would be more ideal to simply read -from only those regions of the file rather than to optimistically cache some -number of bytes ahead of the read (up to 2MB in Linux) if the bytes cached won't -be accessed. Linux implements the fadvise64(2) system call for applications to -specify that a range of a file will not be accessed sequentially. The advice bit -FADV_RANDOM turns off the readahead optimization for the given range in the -given file. However fadvise64 is rarely used by applications so Linux implements -a readahead backoff strategy if reads are not sequential. To ensure that -application performance is not degraded, the sentry must implement a similar -backoff strategy. diff --git a/pkg/sentry/fs/fsutil/dirty_set_impl.go b/pkg/sentry/fs/fsutil/dirty_set_impl.go new file mode 100755 index 000000000..588693f31 --- /dev/null +++ b/pkg/sentry/fs/fsutil/dirty_set_impl.go @@ -0,0 +1,1274 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + DirtyminDegree = 3 + + DirtymaxDegree = 2 * DirtyminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type DirtySet struct { + root Dirtynode `state:".(*DirtySegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *DirtySet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *DirtySet) IsEmptyRange(r __generics_imported0.MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *DirtySet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *DirtySet) SpanRange(r __generics_imported0.MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *DirtySet) FirstSegment() DirtyIterator { + if s.root.nrSegments == 0 { + return DirtyIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *DirtySet) LastSegment() DirtyIterator { + if s.root.nrSegments == 0 { + return DirtyIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *DirtySet) FirstGap() DirtyGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return DirtyGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *DirtySet) LastGap() DirtyGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return DirtyGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *DirtySet) Find(key uint64) (DirtyIterator, DirtyGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return DirtyIterator{n, i}, DirtyGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return DirtyIterator{}, DirtyGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *DirtySet) FindSegment(key uint64) DirtyIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *DirtySet) LowerBoundSegment(min uint64) DirtyIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *DirtySet) UpperBoundSegment(max uint64) DirtyIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *DirtySet) FindGap(key uint64) DirtyGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *DirtySet) LowerBoundGap(min uint64) DirtyGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *DirtySet) UpperBoundGap(max uint64) DirtyGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *DirtySet) Add(r __generics_imported0.MappableRange, val DirtyInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *DirtySet) AddWithoutMerging(r __generics_imported0.MappableRange, val DirtyInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *DirtySet) Insert(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (dirtySetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (dirtySetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *DirtySet) InsertWithoutMerging(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *DirtySet) InsertWithoutMergingUnchecked(gap DirtyGapIterator, r __generics_imported0.MappableRange, val DirtyInfo) DirtyIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return DirtyIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *DirtySet) Remove(seg DirtyIterator) DirtyGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + dirtySetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(DirtyGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *DirtySet) RemoveAll() { + s.root = Dirtynode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *DirtySet) RemoveRange(r __generics_imported0.MappableRange) DirtyGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *DirtySet) Merge(first, second DirtyIterator) DirtyIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *DirtySet) MergeUnchecked(first, second DirtyIterator) DirtyIterator { + if first.End() == second.Start() { + if mval, ok := (dirtySetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return DirtyIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *DirtySet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *DirtySet) MergeRange(r __generics_imported0.MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *DirtySet) MergeAdjacent(r __generics_imported0.MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *DirtySet) Split(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *DirtySet) SplitUnchecked(seg DirtyIterator, split uint64) (DirtyIterator, DirtyIterator) { + val1, val2 := (dirtySetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *DirtySet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *DirtySet) Isolate(seg DirtyIterator, r __generics_imported0.MappableRange) DirtyIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *DirtySet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg DirtyIterator)) DirtyGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return DirtyGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return DirtyGapIterator{} + } + } +} + +// +stateify savable +type Dirtynode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Dirtynode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [DirtymaxDegree - 1]__generics_imported0.MappableRange + values [DirtymaxDegree - 1]DirtyInfo + children [DirtymaxDegree]*Dirtynode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Dirtynode) firstSegment() DirtyIterator { + for n.hasChildren { + n = n.children[0] + } + return DirtyIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Dirtynode) lastSegment() DirtyIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return DirtyIterator{n, n.nrSegments - 1} +} + +func (n *Dirtynode) prevSibling() *Dirtynode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Dirtynode) nextSibling() *Dirtynode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Dirtynode) rebalanceBeforeInsert(gap DirtyGapIterator) DirtyGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < DirtymaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:DirtyminDegree-1], n.keys[:DirtyminDegree-1]) + copy(left.values[:DirtyminDegree-1], n.values[:DirtyminDegree-1]) + copy(right.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:]) + copy(right.values[:DirtyminDegree-1], n.values[DirtyminDegree:]) + n.keys[0], n.values[0] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1] + DirtyzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:DirtyminDegree], n.children[:DirtyminDegree]) + copy(right.children[:DirtyminDegree], n.children[DirtyminDegree:]) + DirtyzeroNodeSlice(n.children[2:]) + for i := 0; i < DirtyminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < DirtyminDegree { + return DirtyGapIterator{left, gap.index} + } + return DirtyGapIterator{right, gap.index - DirtyminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[DirtyminDegree-1], n.values[DirtyminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Dirtynode{ + nrSegments: DirtyminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:DirtyminDegree-1], n.keys[DirtyminDegree:]) + copy(sibling.values[:DirtyminDegree-1], n.values[DirtyminDegree:]) + DirtyzeroValueSlice(n.values[DirtyminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:DirtyminDegree], n.children[DirtyminDegree:]) + DirtyzeroNodeSlice(n.children[DirtyminDegree:]) + for i := 0; i < DirtyminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = DirtyminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < DirtyminDegree { + return gap + } + return DirtyGapIterator{sibling, gap.index - DirtyminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Dirtynode) rebalanceAfterRemove(gap DirtyGapIterator) DirtyGapIterator { + for { + if n.nrSegments >= DirtyminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return DirtyGapIterator{n, 0} + } + if gap.node == n { + return DirtyGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= DirtyminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + dirtySetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return DirtyGapIterator{n, n.nrSegments} + } + return DirtyGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return DirtyGapIterator{p, gap.index} + } + if gap.node == right { + return DirtyGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Dirtynode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = DirtyGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + dirtySetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type DirtyIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Dirtynode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg DirtyIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg DirtyIterator) Range() __generics_imported0.MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg DirtyIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg DirtyIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg DirtyIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetRange(r __generics_imported0.MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg DirtyIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg DirtyIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg DirtyIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg DirtyIterator) Value() DirtyInfo { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg DirtyIterator) ValuePtr() *DirtyInfo { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg DirtyIterator) SetValue(val DirtyInfo) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg DirtyIterator) PrevSegment() DirtyIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return DirtyIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return DirtyIterator{} + } + return DirtysegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg DirtyIterator) NextSegment() DirtyIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return DirtyIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return DirtyIterator{} + } + return DirtysegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg DirtyIterator) PrevGap() DirtyGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return DirtyGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg DirtyIterator) NextGap() DirtyGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return DirtyGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg DirtyIterator) PrevNonEmpty() (DirtyIterator, DirtyGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return DirtyIterator{}, gap + } + return gap.PrevSegment(), DirtyGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg DirtyIterator) NextNonEmpty() (DirtyIterator, DirtyGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return DirtyIterator{}, gap + } + return gap.NextSegment(), DirtyGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type DirtyGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Dirtynode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap DirtyGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap DirtyGapIterator) Range() __generics_imported0.MappableRange { + return __generics_imported0.MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap DirtyGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return dirtySetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap DirtyGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return dirtySetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap DirtyGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap DirtyGapIterator) PrevSegment() DirtyIterator { + return DirtysegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap DirtyGapIterator) NextSegment() DirtyIterator { + return DirtysegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap DirtyGapIterator) PrevGap() DirtyGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return DirtyGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap DirtyGapIterator) NextGap() DirtyGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return DirtyGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func DirtysegmentBeforePosition(n *Dirtynode, i int) DirtyIterator { + for i == 0 { + if n.parent == nil { + return DirtyIterator{} + } + n, i = n.parent, n.parentIndex + } + return DirtyIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func DirtysegmentAfterPosition(n *Dirtynode, i int) DirtyIterator { + for i == n.nrSegments { + if n.parent == nil { + return DirtyIterator{} + } + n, i = n.parent, n.parentIndex + } + return DirtyIterator{n, i} +} + +func DirtyzeroValueSlice(slice []DirtyInfo) { + + for i := range slice { + dirtySetFunctions{}.ClearValue(&slice[i]) + } +} + +func DirtyzeroNodeSlice(slice []*Dirtynode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *DirtySet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *Dirtynode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Dirtynode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type DirtySegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []DirtyInfo +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *DirtySet) ExportSortedSlices() *DirtySegmentDataSlices { + var sds DirtySegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *DirtySet) ImportSortedSlices(sds *DirtySegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *DirtySet) saveRoot() *DirtySegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *DirtySet) loadRoot(sds *DirtySegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go deleted file mode 100644 index 75575d994..000000000 --- a/pkg/sentry/fs/fsutil/dirty_set_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsutil - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -func TestDirtySet(t *testing.T) { - var set DirtySet - set.MarkDirty(memmap.MappableRange{0, 2 * usermem.PageSize}) - set.KeepDirty(memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize}) - set.MarkClean(memmap.MappableRange{0, 2 * usermem.PageSize}) - want := &DirtySegmentDataSlices{ - Start: []uint64{usermem.PageSize}, - End: []uint64{2 * usermem.PageSize}, - Values: []DirtyInfo{{Keep: true}}, - } - if got := set.ExportSortedSlices(); !reflect.DeepEqual(got, want) { - t.Errorf("set:\n\tgot %v,\n\twant %v", got, want) - } -} diff --git a/pkg/sentry/fs/fsutil/file_range_set_impl.go b/pkg/sentry/fs/fsutil/file_range_set_impl.go new file mode 100755 index 000000000..228893348 --- /dev/null +++ b/pkg/sentry/fs/fsutil/file_range_set_impl.go @@ -0,0 +1,1274 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/memmap" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + FileRangeminDegree = 3 + + FileRangemaxDegree = 2 * FileRangeminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type FileRangeSet struct { + root FileRangenode `state:".(*FileRangeSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *FileRangeSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *FileRangeSet) IsEmptyRange(r __generics_imported0.MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *FileRangeSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *FileRangeSet) SpanRange(r __generics_imported0.MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *FileRangeSet) FirstSegment() FileRangeIterator { + if s.root.nrSegments == 0 { + return FileRangeIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *FileRangeSet) LastSegment() FileRangeIterator { + if s.root.nrSegments == 0 { + return FileRangeIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *FileRangeSet) FirstGap() FileRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return FileRangeGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *FileRangeSet) LastGap() FileRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FileRangeGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *FileRangeSet) Find(key uint64) (FileRangeIterator, FileRangeGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return FileRangeIterator{n, i}, FileRangeGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return FileRangeIterator{}, FileRangeGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *FileRangeSet) FindSegment(key uint64) FileRangeIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *FileRangeSet) LowerBoundSegment(min uint64) FileRangeIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *FileRangeSet) UpperBoundSegment(max uint64) FileRangeIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *FileRangeSet) FindGap(key uint64) FileRangeGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *FileRangeSet) LowerBoundGap(min uint64) FileRangeGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *FileRangeSet) UpperBoundGap(max uint64) FileRangeGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *FileRangeSet) Add(r __generics_imported0.MappableRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *FileRangeSet) AddWithoutMerging(r __generics_imported0.MappableRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *FileRangeSet) Insert(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (fileRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (fileRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (fileRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *FileRangeSet) InsertWithoutMerging(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *FileRangeSet) InsertWithoutMergingUnchecked(gap FileRangeGapIterator, r __generics_imported0.MappableRange, val uint64) FileRangeIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return FileRangeIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *FileRangeSet) Remove(seg FileRangeIterator) FileRangeGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + fileRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(FileRangeGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *FileRangeSet) RemoveAll() { + s.root = FileRangenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *FileRangeSet) RemoveRange(r __generics_imported0.MappableRange) FileRangeGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *FileRangeSet) Merge(first, second FileRangeIterator) FileRangeIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *FileRangeSet) MergeUnchecked(first, second FileRangeIterator) FileRangeIterator { + if first.End() == second.Start() { + if mval, ok := (fileRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return FileRangeIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *FileRangeSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *FileRangeSet) MergeRange(r __generics_imported0.MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *FileRangeSet) MergeAdjacent(r __generics_imported0.MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *FileRangeSet) Split(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *FileRangeSet) SplitUnchecked(seg FileRangeIterator, split uint64) (FileRangeIterator, FileRangeIterator) { + val1, val2 := (fileRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *FileRangeSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *FileRangeSet) Isolate(seg FileRangeIterator, r __generics_imported0.MappableRange) FileRangeIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *FileRangeSet) ApplyContiguous(r __generics_imported0.MappableRange, fn func(seg FileRangeIterator)) FileRangeGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return FileRangeGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return FileRangeGapIterator{} + } + } +} + +// +stateify savable +type FileRangenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *FileRangenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [FileRangemaxDegree - 1]__generics_imported0.MappableRange + values [FileRangemaxDegree - 1]uint64 + children [FileRangemaxDegree]*FileRangenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FileRangenode) firstSegment() FileRangeIterator { + for n.hasChildren { + n = n.children[0] + } + return FileRangeIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *FileRangenode) lastSegment() FileRangeIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return FileRangeIterator{n, n.nrSegments - 1} +} + +func (n *FileRangenode) prevSibling() *FileRangenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *FileRangenode) nextSibling() *FileRangenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *FileRangenode) rebalanceBeforeInsert(gap FileRangeGapIterator) FileRangeGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < FileRangemaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:FileRangeminDegree-1], n.keys[:FileRangeminDegree-1]) + copy(left.values[:FileRangeminDegree-1], n.values[:FileRangeminDegree-1]) + copy(right.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:]) + copy(right.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:]) + n.keys[0], n.values[0] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1] + FileRangezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:FileRangeminDegree], n.children[:FileRangeminDegree]) + copy(right.children[:FileRangeminDegree], n.children[FileRangeminDegree:]) + FileRangezeroNodeSlice(n.children[2:]) + for i := 0; i < FileRangeminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < FileRangeminDegree { + return FileRangeGapIterator{left, gap.index} + } + return FileRangeGapIterator{right, gap.index - FileRangeminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[FileRangeminDegree-1], n.values[FileRangeminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &FileRangenode{ + nrSegments: FileRangeminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:FileRangeminDegree-1], n.keys[FileRangeminDegree:]) + copy(sibling.values[:FileRangeminDegree-1], n.values[FileRangeminDegree:]) + FileRangezeroValueSlice(n.values[FileRangeminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:FileRangeminDegree], n.children[FileRangeminDegree:]) + FileRangezeroNodeSlice(n.children[FileRangeminDegree:]) + for i := 0; i < FileRangeminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = FileRangeminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < FileRangeminDegree { + return gap + } + return FileRangeGapIterator{sibling, gap.index - FileRangeminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *FileRangenode) rebalanceAfterRemove(gap FileRangeGapIterator) FileRangeGapIterator { + for { + if n.nrSegments >= FileRangeminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + fileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return FileRangeGapIterator{n, 0} + } + if gap.node == n { + return FileRangeGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= FileRangeminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + fileRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return FileRangeGapIterator{n, n.nrSegments} + } + return FileRangeGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return FileRangeGapIterator{p, gap.index} + } + if gap.node == right { + return FileRangeGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *FileRangenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = FileRangeGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + fileRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FileRangeIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *FileRangenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg FileRangeIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg FileRangeIterator) Range() __generics_imported0.MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg FileRangeIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg FileRangeIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg FileRangeIterator) SetRangeUnchecked(r __generics_imported0.MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg FileRangeIterator) SetRange(r __generics_imported0.MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg FileRangeIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg FileRangeIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg FileRangeIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg FileRangeIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg FileRangeIterator) Value() uint64 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg FileRangeIterator) ValuePtr() *uint64 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg FileRangeIterator) SetValue(val uint64) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg FileRangeIterator) PrevSegment() FileRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return FileRangeIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return FileRangeIterator{} + } + return FileRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg FileRangeIterator) NextSegment() FileRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return FileRangeIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return FileRangeIterator{} + } + return FileRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg FileRangeIterator) PrevGap() FileRangeGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return FileRangeGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg FileRangeIterator) NextGap() FileRangeGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return FileRangeGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg FileRangeIterator) PrevNonEmpty() (FileRangeIterator, FileRangeGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return FileRangeIterator{}, gap + } + return gap.PrevSegment(), FileRangeGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg FileRangeIterator) NextNonEmpty() (FileRangeIterator, FileRangeGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return FileRangeIterator{}, gap + } + return gap.NextSegment(), FileRangeGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type FileRangeGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *FileRangenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap FileRangeGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap FileRangeGapIterator) Range() __generics_imported0.MappableRange { + return __generics_imported0.MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap FileRangeGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return fileRangeSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap FileRangeGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return fileRangeSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap FileRangeGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap FileRangeGapIterator) PrevSegment() FileRangeIterator { + return FileRangesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap FileRangeGapIterator) NextSegment() FileRangeIterator { + return FileRangesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap FileRangeGapIterator) PrevGap() FileRangeGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return FileRangeGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap FileRangeGapIterator) NextGap() FileRangeGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return FileRangeGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func FileRangesegmentBeforePosition(n *FileRangenode, i int) FileRangeIterator { + for i == 0 { + if n.parent == nil { + return FileRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return FileRangeIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func FileRangesegmentAfterPosition(n *FileRangenode, i int) FileRangeIterator { + for i == n.nrSegments { + if n.parent == nil { + return FileRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return FileRangeIterator{n, i} +} + +func FileRangezeroValueSlice(slice []uint64) { + + for i := range slice { + fileRangeSetFunctions{}.ClearValue(&slice[i]) + } +} + +func FileRangezeroNodeSlice(slice []*FileRangenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *FileRangeSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *FileRangenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *FileRangenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type FileRangeSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []uint64 +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *FileRangeSet) ExportSortedSlices() *FileRangeSegmentDataSlices { + var sds FileRangeSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *FileRangeSet) ImportSortedSlices(sds *FileRangeSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *FileRangeSet) saveRoot() *FileRangeSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *FileRangeSet) loadRoot(sds *FileRangeSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/frame_ref_set_impl.go b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go new file mode 100755 index 000000000..6bc3f00ea --- /dev/null +++ b/pkg/sentry/fs/fsutil/frame_ref_set_impl.go @@ -0,0 +1,1274 @@ +package fsutil + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/platform" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + frameRefminDegree = 3 + + frameRefmaxDegree = 2 * frameRefminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type frameRefSet struct { + root frameRefnode `state:".(*frameRefSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *frameRefSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *frameRefSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *frameRefSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *frameRefSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *frameRefSet) FirstSegment() frameRefIterator { + if s.root.nrSegments == 0 { + return frameRefIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *frameRefSet) LastSegment() frameRefIterator { + if s.root.nrSegments == 0 { + return frameRefIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *frameRefSet) FirstGap() frameRefGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return frameRefGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *frameRefSet) LastGap() frameRefGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return frameRefGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *frameRefSet) Find(key uint64) (frameRefIterator, frameRefGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return frameRefIterator{n, i}, frameRefGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return frameRefIterator{}, frameRefGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *frameRefSet) FindSegment(key uint64) frameRefIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *frameRefSet) LowerBoundSegment(min uint64) frameRefIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *frameRefSet) UpperBoundSegment(max uint64) frameRefIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *frameRefSet) FindGap(key uint64) frameRefGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *frameRefSet) LowerBoundGap(min uint64) frameRefGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *frameRefSet) UpperBoundGap(max uint64) frameRefGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *frameRefSet) Add(r __generics_imported0.FileRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *frameRefSet) AddWithoutMerging(r __generics_imported0.FileRange, val uint64) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *frameRefSet) Insert(gap frameRefGapIterator, r __generics_imported0.FileRange, val uint64) frameRefIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (frameRefSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (frameRefSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (frameRefSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *frameRefSet) InsertWithoutMerging(gap frameRefGapIterator, r __generics_imported0.FileRange, val uint64) frameRefIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *frameRefSet) InsertWithoutMergingUnchecked(gap frameRefGapIterator, r __generics_imported0.FileRange, val uint64) frameRefIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return frameRefIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *frameRefSet) Remove(seg frameRefIterator) frameRefGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + frameRefSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(frameRefGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *frameRefSet) RemoveAll() { + s.root = frameRefnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *frameRefSet) RemoveRange(r __generics_imported0.FileRange) frameRefGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *frameRefSet) Merge(first, second frameRefIterator) frameRefIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *frameRefSet) MergeUnchecked(first, second frameRefIterator) frameRefIterator { + if first.End() == second.Start() { + if mval, ok := (frameRefSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return frameRefIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *frameRefSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *frameRefSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *frameRefSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *frameRefSet) Split(seg frameRefIterator, split uint64) (frameRefIterator, frameRefIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *frameRefSet) SplitUnchecked(seg frameRefIterator, split uint64) (frameRefIterator, frameRefIterator) { + val1, val2 := (frameRefSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *frameRefSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *frameRefSet) Isolate(seg frameRefIterator, r __generics_imported0.FileRange) frameRefIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *frameRefSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg frameRefIterator)) frameRefGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return frameRefGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return frameRefGapIterator{} + } + } +} + +// +stateify savable +type frameRefnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *frameRefnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [frameRefmaxDegree - 1]__generics_imported0.FileRange + values [frameRefmaxDegree - 1]uint64 + children [frameRefmaxDegree]*frameRefnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *frameRefnode) firstSegment() frameRefIterator { + for n.hasChildren { + n = n.children[0] + } + return frameRefIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *frameRefnode) lastSegment() frameRefIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return frameRefIterator{n, n.nrSegments - 1} +} + +func (n *frameRefnode) prevSibling() *frameRefnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *frameRefnode) nextSibling() *frameRefnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *frameRefnode) rebalanceBeforeInsert(gap frameRefGapIterator) frameRefGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < frameRefmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &frameRefnode{ + nrSegments: frameRefminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &frameRefnode{ + nrSegments: frameRefminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:frameRefminDegree-1], n.keys[:frameRefminDegree-1]) + copy(left.values[:frameRefminDegree-1], n.values[:frameRefminDegree-1]) + copy(right.keys[:frameRefminDegree-1], n.keys[frameRefminDegree:]) + copy(right.values[:frameRefminDegree-1], n.values[frameRefminDegree:]) + n.keys[0], n.values[0] = n.keys[frameRefminDegree-1], n.values[frameRefminDegree-1] + frameRefzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:frameRefminDegree], n.children[:frameRefminDegree]) + copy(right.children[:frameRefminDegree], n.children[frameRefminDegree:]) + frameRefzeroNodeSlice(n.children[2:]) + for i := 0; i < frameRefminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < frameRefminDegree { + return frameRefGapIterator{left, gap.index} + } + return frameRefGapIterator{right, gap.index - frameRefminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[frameRefminDegree-1], n.values[frameRefminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &frameRefnode{ + nrSegments: frameRefminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:frameRefminDegree-1], n.keys[frameRefminDegree:]) + copy(sibling.values[:frameRefminDegree-1], n.values[frameRefminDegree:]) + frameRefzeroValueSlice(n.values[frameRefminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:frameRefminDegree], n.children[frameRefminDegree:]) + frameRefzeroNodeSlice(n.children[frameRefminDegree:]) + for i := 0; i < frameRefminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = frameRefminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < frameRefminDegree { + return gap + } + return frameRefGapIterator{sibling, gap.index - frameRefminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *frameRefnode) rebalanceAfterRemove(gap frameRefGapIterator) frameRefGapIterator { + for { + if n.nrSegments >= frameRefminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= frameRefminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + frameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return frameRefGapIterator{n, 0} + } + if gap.node == n { + return frameRefGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= frameRefminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + frameRefSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return frameRefGapIterator{n, n.nrSegments} + } + return frameRefGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return frameRefGapIterator{p, gap.index} + } + if gap.node == right { + return frameRefGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *frameRefnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = frameRefGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + frameRefSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type frameRefIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *frameRefnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg frameRefIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg frameRefIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg frameRefIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg frameRefIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg frameRefIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg frameRefIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg frameRefIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg frameRefIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg frameRefIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg frameRefIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg frameRefIterator) Value() uint64 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg frameRefIterator) ValuePtr() *uint64 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg frameRefIterator) SetValue(val uint64) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg frameRefIterator) PrevSegment() frameRefIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return frameRefIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return frameRefIterator{} + } + return frameRefsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg frameRefIterator) NextSegment() frameRefIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return frameRefIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return frameRefIterator{} + } + return frameRefsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg frameRefIterator) PrevGap() frameRefGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return frameRefGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg frameRefIterator) NextGap() frameRefGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return frameRefGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg frameRefIterator) PrevNonEmpty() (frameRefIterator, frameRefGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return frameRefIterator{}, gap + } + return gap.PrevSegment(), frameRefGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg frameRefIterator) NextNonEmpty() (frameRefIterator, frameRefGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return frameRefIterator{}, gap + } + return gap.NextSegment(), frameRefGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type frameRefGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *frameRefnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap frameRefGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap frameRefGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap frameRefGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return frameRefSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap frameRefGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return frameRefSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap frameRefGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap frameRefGapIterator) PrevSegment() frameRefIterator { + return frameRefsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap frameRefGapIterator) NextSegment() frameRefIterator { + return frameRefsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap frameRefGapIterator) PrevGap() frameRefGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return frameRefGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap frameRefGapIterator) NextGap() frameRefGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return frameRefGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func frameRefsegmentBeforePosition(n *frameRefnode, i int) frameRefIterator { + for i == 0 { + if n.parent == nil { + return frameRefIterator{} + } + n, i = n.parent, n.parentIndex + } + return frameRefIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func frameRefsegmentAfterPosition(n *frameRefnode, i int) frameRefIterator { + for i == n.nrSegments { + if n.parent == nil { + return frameRefIterator{} + } + n, i = n.parent, n.parentIndex + } + return frameRefIterator{n, i} +} + +func frameRefzeroValueSlice(slice []uint64) { + + for i := range slice { + frameRefSetFunctions{}.ClearValue(&slice[i]) + } +} + +func frameRefzeroNodeSlice(slice []*frameRefnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *frameRefSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *frameRefnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *frameRefnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type frameRefSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []uint64 +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *frameRefSet) ExportSortedSlices() *frameRefSegmentDataSlices { + var sds frameRefSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *frameRefSet) ImportSortedSlices(sds *frameRefSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *frameRefSet) saveRoot() *frameRefSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *frameRefSet) loadRoot(sds *frameRefSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/fsutil/fsutil_state_autogen.go b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go new file mode 100755 index 000000000..6371a66a5 --- /dev/null +++ b/pkg/sentry/fs/fsutil/fsutil_state_autogen.go @@ -0,0 +1,349 @@ +// automatically generated by stateify. + +package fsutil + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *DirtyInfo) beforeSave() {} +func (x *DirtyInfo) save(m state.Map) { + x.beforeSave() + m.Save("Keep", &x.Keep) +} + +func (x *DirtyInfo) afterLoad() {} +func (x *DirtyInfo) load(m state.Map) { + m.Load("Keep", &x.Keep) +} + +func (x *DirtySet) beforeSave() {} +func (x *DirtySet) save(m state.Map) { + x.beforeSave() + var root *DirtySegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *DirtySet) afterLoad() {} +func (x *DirtySet) load(m state.Map) { + m.LoadValue("root", new(*DirtySegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*DirtySegmentDataSlices)) }) +} + +func (x *Dirtynode) beforeSave() {} +func (x *Dirtynode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *Dirtynode) afterLoad() {} +func (x *Dirtynode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *DirtySegmentDataSlices) beforeSave() {} +func (x *DirtySegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *DirtySegmentDataSlices) afterLoad() {} +func (x *DirtySegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *StaticDirFileOperations) beforeSave() {} +func (x *StaticDirFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("dentryMap", &x.dentryMap) + m.Save("dirCursor", &x.dirCursor) +} + +func (x *StaticDirFileOperations) afterLoad() {} +func (x *StaticDirFileOperations) load(m state.Map) { + m.Load("dentryMap", &x.dentryMap) + m.Load("dirCursor", &x.dirCursor) +} + +func (x *NoReadWriteFile) beforeSave() {} +func (x *NoReadWriteFile) save(m state.Map) { + x.beforeSave() +} + +func (x *NoReadWriteFile) afterLoad() {} +func (x *NoReadWriteFile) load(m state.Map) { +} + +func (x *FileStaticContentReader) beforeSave() {} +func (x *FileStaticContentReader) save(m state.Map) { + x.beforeSave() + m.Save("content", &x.content) +} + +func (x *FileStaticContentReader) afterLoad() {} +func (x *FileStaticContentReader) load(m state.Map) { + m.Load("content", &x.content) +} + +func (x *FileRangeSet) beforeSave() {} +func (x *FileRangeSet) save(m state.Map) { + x.beforeSave() + var root *FileRangeSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *FileRangeSet) afterLoad() {} +func (x *FileRangeSet) load(m state.Map) { + m.LoadValue("root", new(*FileRangeSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*FileRangeSegmentDataSlices)) }) +} + +func (x *FileRangenode) beforeSave() {} +func (x *FileRangenode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *FileRangenode) afterLoad() {} +func (x *FileRangenode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *FileRangeSegmentDataSlices) beforeSave() {} +func (x *FileRangeSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *FileRangeSegmentDataSlices) afterLoad() {} +func (x *FileRangeSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *frameRefSet) beforeSave() {} +func (x *frameRefSet) save(m state.Map) { + x.beforeSave() + var root *frameRefSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *frameRefSet) afterLoad() {} +func (x *frameRefSet) load(m state.Map) { + m.LoadValue("root", new(*frameRefSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*frameRefSegmentDataSlices)) }) +} + +func (x *frameRefnode) beforeSave() {} +func (x *frameRefnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *frameRefnode) afterLoad() {} +func (x *frameRefnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *frameRefSegmentDataSlices) beforeSave() {} +func (x *frameRefSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *frameRefSegmentDataSlices) afterLoad() {} +func (x *frameRefSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *HostFileMapper) beforeSave() {} +func (x *HostFileMapper) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) +} + +func (x *HostFileMapper) load(m state.Map) { + m.Load("refs", &x.refs) + m.AfterLoad(x.afterLoad) +} + +func (x *HostMappable) beforeSave() {} +func (x *HostMappable) save(m state.Map) { + x.beforeSave() + m.Save("hostFileMapper", &x.hostFileMapper) + m.Save("backingFile", &x.backingFile) + m.Save("mappings", &x.mappings) +} + +func (x *HostMappable) afterLoad() {} +func (x *HostMappable) load(m state.Map) { + m.Load("hostFileMapper", &x.hostFileMapper) + m.Load("backingFile", &x.backingFile) + m.Load("mappings", &x.mappings) +} + +func (x *SimpleFileInode) beforeSave() {} +func (x *SimpleFileInode) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *SimpleFileInode) afterLoad() {} +func (x *SimpleFileInode) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *NoReadWriteFileInode) beforeSave() {} +func (x *NoReadWriteFileInode) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *NoReadWriteFileInode) afterLoad() {} +func (x *NoReadWriteFileInode) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) +} + +func (x *InodeSimpleAttributes) beforeSave() {} +func (x *InodeSimpleAttributes) save(m state.Map) { + x.beforeSave() + m.Save("fsType", &x.fsType) + m.Save("unstable", &x.unstable) +} + +func (x *InodeSimpleAttributes) afterLoad() {} +func (x *InodeSimpleAttributes) load(m state.Map) { + m.Load("fsType", &x.fsType) + m.Load("unstable", &x.unstable) +} + +func (x *InodeSimpleExtendedAttributes) beforeSave() {} +func (x *InodeSimpleExtendedAttributes) save(m state.Map) { + x.beforeSave() + m.Save("xattrs", &x.xattrs) +} + +func (x *InodeSimpleExtendedAttributes) afterLoad() {} +func (x *InodeSimpleExtendedAttributes) load(m state.Map) { + m.Load("xattrs", &x.xattrs) +} + +func (x *staticFile) beforeSave() {} +func (x *staticFile) save(m state.Map) { + x.beforeSave() + m.Save("FileStaticContentReader", &x.FileStaticContentReader) +} + +func (x *staticFile) afterLoad() {} +func (x *staticFile) load(m state.Map) { + m.Load("FileStaticContentReader", &x.FileStaticContentReader) +} + +func (x *InodeStaticFileGetter) beforeSave() {} +func (x *InodeStaticFileGetter) save(m state.Map) { + x.beforeSave() + m.Save("Contents", &x.Contents) +} + +func (x *InodeStaticFileGetter) afterLoad() {} +func (x *InodeStaticFileGetter) load(m state.Map) { + m.Load("Contents", &x.Contents) +} + +func (x *CachingInodeOperations) beforeSave() {} +func (x *CachingInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("backingFile", &x.backingFile) + m.Save("mfp", &x.mfp) + m.Save("forcePageCache", &x.forcePageCache) + m.Save("attr", &x.attr) + m.Save("dirtyAttr", &x.dirtyAttr) + m.Save("mappings", &x.mappings) + m.Save("cache", &x.cache) + m.Save("dirty", &x.dirty) + m.Save("hostFileMapper", &x.hostFileMapper) + m.Save("refs", &x.refs) +} + +func (x *CachingInodeOperations) afterLoad() {} +func (x *CachingInodeOperations) load(m state.Map) { + m.Load("backingFile", &x.backingFile) + m.Load("mfp", &x.mfp) + m.Load("forcePageCache", &x.forcePageCache) + m.Load("attr", &x.attr) + m.Load("dirtyAttr", &x.dirtyAttr) + m.Load("mappings", &x.mappings) + m.Load("cache", &x.cache) + m.Load("dirty", &x.dirty) + m.Load("hostFileMapper", &x.hostFileMapper) + m.Load("refs", &x.refs) +} + +func init() { + state.Register("fsutil.DirtyInfo", (*DirtyInfo)(nil), state.Fns{Save: (*DirtyInfo).save, Load: (*DirtyInfo).load}) + state.Register("fsutil.DirtySet", (*DirtySet)(nil), state.Fns{Save: (*DirtySet).save, Load: (*DirtySet).load}) + state.Register("fsutil.Dirtynode", (*Dirtynode)(nil), state.Fns{Save: (*Dirtynode).save, Load: (*Dirtynode).load}) + state.Register("fsutil.DirtySegmentDataSlices", (*DirtySegmentDataSlices)(nil), state.Fns{Save: (*DirtySegmentDataSlices).save, Load: (*DirtySegmentDataSlices).load}) + state.Register("fsutil.StaticDirFileOperations", (*StaticDirFileOperations)(nil), state.Fns{Save: (*StaticDirFileOperations).save, Load: (*StaticDirFileOperations).load}) + state.Register("fsutil.NoReadWriteFile", (*NoReadWriteFile)(nil), state.Fns{Save: (*NoReadWriteFile).save, Load: (*NoReadWriteFile).load}) + state.Register("fsutil.FileStaticContentReader", (*FileStaticContentReader)(nil), state.Fns{Save: (*FileStaticContentReader).save, Load: (*FileStaticContentReader).load}) + state.Register("fsutil.FileRangeSet", (*FileRangeSet)(nil), state.Fns{Save: (*FileRangeSet).save, Load: (*FileRangeSet).load}) + state.Register("fsutil.FileRangenode", (*FileRangenode)(nil), state.Fns{Save: (*FileRangenode).save, Load: (*FileRangenode).load}) + state.Register("fsutil.FileRangeSegmentDataSlices", (*FileRangeSegmentDataSlices)(nil), state.Fns{Save: (*FileRangeSegmentDataSlices).save, Load: (*FileRangeSegmentDataSlices).load}) + state.Register("fsutil.frameRefSet", (*frameRefSet)(nil), state.Fns{Save: (*frameRefSet).save, Load: (*frameRefSet).load}) + state.Register("fsutil.frameRefnode", (*frameRefnode)(nil), state.Fns{Save: (*frameRefnode).save, Load: (*frameRefnode).load}) + state.Register("fsutil.frameRefSegmentDataSlices", (*frameRefSegmentDataSlices)(nil), state.Fns{Save: (*frameRefSegmentDataSlices).save, Load: (*frameRefSegmentDataSlices).load}) + state.Register("fsutil.HostFileMapper", (*HostFileMapper)(nil), state.Fns{Save: (*HostFileMapper).save, Load: (*HostFileMapper).load}) + state.Register("fsutil.HostMappable", (*HostMappable)(nil), state.Fns{Save: (*HostMappable).save, Load: (*HostMappable).load}) + state.Register("fsutil.SimpleFileInode", (*SimpleFileInode)(nil), state.Fns{Save: (*SimpleFileInode).save, Load: (*SimpleFileInode).load}) + state.Register("fsutil.NoReadWriteFileInode", (*NoReadWriteFileInode)(nil), state.Fns{Save: (*NoReadWriteFileInode).save, Load: (*NoReadWriteFileInode).load}) + state.Register("fsutil.InodeSimpleAttributes", (*InodeSimpleAttributes)(nil), state.Fns{Save: (*InodeSimpleAttributes).save, Load: (*InodeSimpleAttributes).load}) + state.Register("fsutil.InodeSimpleExtendedAttributes", (*InodeSimpleExtendedAttributes)(nil), state.Fns{Save: (*InodeSimpleExtendedAttributes).save, Load: (*InodeSimpleExtendedAttributes).load}) + state.Register("fsutil.staticFile", (*staticFile)(nil), state.Fns{Save: (*staticFile).save, Load: (*staticFile).load}) + state.Register("fsutil.InodeStaticFileGetter", (*InodeStaticFileGetter)(nil), state.Fns{Save: (*InodeStaticFileGetter).save, Load: (*InodeStaticFileGetter).load}) + state.Register("fsutil.CachingInodeOperations", (*CachingInodeOperations)(nil), state.Fns{Save: (*CachingInodeOperations).save, Load: (*CachingInodeOperations).load}) +} diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go deleted file mode 100644 index 82a67b2e5..000000000 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ /dev/null @@ -1,389 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsutil - -import ( - "bytes" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -type noopBackingFile struct{} - -func (noopBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { - return dsts.NumBytes(), nil -} - -func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { - return srcs.NumBytes(), nil -} - -func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { - return nil -} - -func (noopBackingFile) Sync(context.Context) error { - return nil -} - -func (noopBackingFile) FD() int { - return -1 -} - -func (noopBackingFile) Allocate(ctx context.Context, offset int64, length int64) error { - return nil -} - -func TestSetPermissions(t *testing.T) { - ctx := contexttest.Context(t) - - uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{ - Perms: fs.FilePermsFromMode(0444), - }) - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) - defer iops.Release() - - perms := fs.FilePermsFromMode(0777) - if !iops.SetPermissions(ctx, nil, perms) { - t.Fatalf("SetPermissions failed, want success") - } - - // Did permissions change? - if iops.attr.Perms != perms { - t.Fatalf("got perms +%v, want +%v", iops.attr.Perms, perms) - } - - // Did status change time change? - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("got status change time not dirty, want dirty") - } - if iops.attr.StatusChangeTime.Equal(uattr.StatusChangeTime) { - t.Fatalf("got status change time unchanged") - } -} - -func TestSetTimestamps(t *testing.T) { - ctx := contexttest.Context(t) - for _, test := range []struct { - desc string - ts fs.TimeSpec - wantChanged fs.AttrMask - }{ - { - desc: "noop", - ts: fs.TimeSpec{ - ATimeOmit: true, - MTimeOmit: true, - }, - wantChanged: fs.AttrMask{}, - }, - { - desc: "access time only", - ts: fs.TimeSpec{ - ATime: ktime.NowFromContext(ctx), - MTimeOmit: true, - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - }, - }, - { - desc: "modification time only", - ts: fs.TimeSpec{ - ATimeOmit: true, - MTime: ktime.NowFromContext(ctx), - }, - wantChanged: fs.AttrMask{ - ModificationTime: true, - }, - }, - { - desc: "access and modification time", - ts: fs.TimeSpec{ - ATime: ktime.NowFromContext(ctx), - MTime: ktime.NowFromContext(ctx), - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - ModificationTime: true, - }, - }, - { - desc: "system time access and modification time", - ts: fs.TimeSpec{ - ATimeSetSystemTime: true, - MTimeSetSystemTime: true, - }, - wantChanged: fs.AttrMask{ - AccessTime: true, - ModificationTime: true, - }, - }, - } { - t.Run(test.desc, func(t *testing.T) { - ctx := contexttest.Context(t) - - epoch := ktime.ZeroTime - uattr := fs.UnstableAttr{ - AccessTime: epoch, - ModificationTime: epoch, - StatusChangeTime: epoch, - } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) - defer iops.Release() - - if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil { - t.Fatalf("SetTimestamps got error %v, want nil", err) - } - if test.wantChanged.AccessTime { - if !iops.attr.AccessTime.After(uattr.AccessTime) { - t.Fatalf("diritied access time did not advance, want %v > %v", iops.attr.AccessTime, uattr.AccessTime) - } - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("dirty access time requires dirty status change time") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not advance") - } - } - if test.wantChanged.ModificationTime { - if !iops.attr.ModificationTime.After(uattr.ModificationTime) { - t.Fatalf("diritied modification time did not advance") - } - if !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("dirty modification time requires dirty status change time") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not advance") - } - } - }) - } -} - -func TestTruncate(t *testing.T) { - ctx := contexttest.Context(t) - - uattr := fs.UnstableAttr{ - Size: 0, - } - iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/) - defer iops.Release() - - if err := iops.Truncate(ctx, nil, uattr.Size); err != nil { - t.Fatalf("Truncate got error %v, want nil", err) - } - var size int64 = 4096 - if err := iops.Truncate(ctx, nil, size); err != nil { - t.Fatalf("Truncate got error %v, want nil", err) - } - if iops.attr.Size != size { - t.Fatalf("Truncate got %d, want %d", iops.attr.Size, size) - } - if !iops.dirtyAttr.ModificationTime || !iops.dirtyAttr.StatusChangeTime { - t.Fatalf("Truncate did not dirty modification and status change time") - } - if !iops.attr.ModificationTime.After(uattr.ModificationTime) { - t.Fatalf("dirtied modification time did not change") - } - if !iops.attr.StatusChangeTime.After(uattr.StatusChangeTime) { - t.Fatalf("dirtied status change time did not change") - } -} - -type sliceBackingFile struct { - data []byte -} - -func newSliceBackingFile(data []byte) *sliceBackingFile { - return &sliceBackingFile{data} -} - -func (f *sliceBackingFile) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { - r := safemem.BlockSeqReader{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)} - return r.ReadToBlocks(dsts) -} - -func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { - w := safemem.BlockSeqWriter{safemem.BlockSeqOf(safemem.BlockFromSafeSlice(f.data)).DropFirst64(offset)} - return w.WriteFromBlocks(srcs) -} - -func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { - return nil -} - -func (*sliceBackingFile) Sync(context.Context) error { - return nil -} - -func (*sliceBackingFile) FD() int { - return -1 -} - -func (f *sliceBackingFile) Allocate(ctx context.Context, offset int64, length int64) error { - return syserror.EOPNOTSUPP -} - -type noopMappingSpace struct{} - -// Invalidate implements memmap.MappingSpace.Invalidate. -func (noopMappingSpace) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) { -} - -func anonInode(ctx context.Context) *fs.Inode { - return fs.NewInode(&SimpleFileInode{ - InodeSimpleAttributes: NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, - }, 0), - }, fs.NewPseudoMountSource(), fs.StableAttr{ - Type: fs.Anonymous, - BlockSize: usermem.PageSize, - }) -} - -func pagesOf(bs ...byte) []byte { - buf := make([]byte, 0, len(bs)*usermem.PageSize) - for _, b := range bs { - buf = append(buf, bytes.Repeat([]byte{b}, usermem.PageSize)...) - } - return buf -} - -func TestRead(t *testing.T) { - ctx := contexttest.Context(t) - - // Construct a 3-page file. - buf := pagesOf('a', 'b', 'c') - file := fs.NewFile(ctx, fs.NewDirent(anonInode(ctx), "anon"), fs.FileFlags{}, nil) - uattr := fs.UnstableAttr{ - Size: int64(len(buf)), - } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/) - defer iops.Release() - - // Expect the cache to be initially empty. - if cached := iops.cache.Span(); cached != 0 { - t.Errorf("Span got %d, want 0", cached) - } - - // Create a memory mapping of the second page (as CachingInodeOperations - // expects to only cache mapped pages), then call Translate to force it to - // be cached. - var ms noopMappingSpace - ar := usermem.AddrRange{usermem.PageSize, 2 * usermem.PageSize} - if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil { - t.Fatalf("AddMapping got %v, want nil", err) - } - mr := memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize} - if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil { - t.Fatalf("Translate got %v, want nil", err) - } - if cached := iops.cache.Span(); cached != usermem.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, usermem.PageSize) - } - - // Try to read 4 pages. The first and third pages should be read directly - // from the "file", the second page should be read from the cache, and only - // 3 pages (the size of the file) should be readable. - rbuf := make([]byte, 4*usermem.PageSize) - dst := usermem.BytesIOSequence(rbuf) - n, err := iops.Read(ctx, file, dst, 0) - if n != 3*usermem.PageSize || (err != nil && err != io.EOF) { - t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*usermem.PageSize) - } - rbuf = rbuf[:3*usermem.PageSize] - - // Did we get the bytes we expect? - if !bytes.Equal(rbuf, buf) { - t.Errorf("Read back bytes %v, want %v", rbuf, buf) - } - - // Delete the memory mapping before iops.Release(). The cached page will - // either be evicted by ctx's pgalloc.MemoryFile, or dropped by - // iops.Release(). - iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true) -} - -func TestWrite(t *testing.T) { - ctx := contexttest.Context(t) - - // Construct a 4-page file. - buf := pagesOf('a', 'b', 'c', 'd') - orig := append([]byte(nil), buf...) - inode := anonInode(ctx) - uattr := fs.UnstableAttr{ - Size: int64(len(buf)), - } - iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/) - defer iops.Release() - - // Expect the cache to be initially empty. - if cached := iops.cache.Span(); cached != 0 { - t.Errorf("Span got %d, want 0", cached) - } - - // Create a memory mapping of the second and third pages (as - // CachingInodeOperations expects to only cache mapped pages), then call - // Translate to force them to be cached. - var ms noopMappingSpace - ar := usermem.AddrRange{usermem.PageSize, 3 * usermem.PageSize} - if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil { - t.Fatalf("AddMapping got %v, want nil", err) - } - defer iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true) - mr := memmap.MappableRange{usermem.PageSize, 3 * usermem.PageSize} - if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil { - t.Fatalf("Translate got %v, want nil", err) - } - if cached := iops.cache.Span(); cached != 2*usermem.PageSize { - t.Errorf("SpanRange got %d, want %d", cached, 2*usermem.PageSize) - } - - // Write to the first 2 pages. - wbuf := pagesOf('e', 'f') - src := usermem.BytesIOSequence(wbuf) - n, err := iops.Write(ctx, src, 0) - if n != 2*usermem.PageSize || err != nil { - t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*usermem.PageSize) - } - - // The first page should have been written directly, since it was not cached. - want := append([]byte(nil), orig...) - copy(want, pagesOf('e')) - if !bytes.Equal(buf, want) { - t.Errorf("File contents are %v, want %v", buf, want) - } - - // Sync back to the "backing file". - if err := iops.WriteOut(ctx, inode); err != nil { - t.Errorf("Sync got %v, want nil", err) - } - - // Now the second page should have been written as well. - copy(want[usermem.PageSize:], pagesOf('f')) - if !bytes.Equal(buf, want) { - t.Errorf("File contents are %v, want %v", buf, want) - } -} diff --git a/pkg/sentry/fs/g3doc/inotify.md b/pkg/sentry/fs/g3doc/inotify.md deleted file mode 100644 index e1630553b..000000000 --- a/pkg/sentry/fs/g3doc/inotify.md +++ /dev/null @@ -1,122 +0,0 @@ -# Inotify - -Inotify implements the like-named filesystem event notification system for the -sentry, see `inotify(7)`. - -## Architecture - -For the most part, the sentry implementation of inotify mirrors the Linux -architecture. Inotify instances (i.e. the fd returned by inotify_init(2)) are -backed by a pseudo-filesystem. Events are generated from various places in the -sentry, including the [syscall layer][syscall_dir], the [vfs layer][dirent] and -the [process fd table][fd_map]. Watches are stored in inodes and generated -events are queued to the inotify instance owning the watches for delivery to the -user. - -## Objects - -Here is a brief description of the existing and new objects involved in the -sentry inotify mechanism, and how they interact: - -### [`fs.Inotify`][inotify] - -- An inotify instances, created by inotify_init(2)/inotify_init1(2). -- The inotify fd has a `fs.Dirent`, supports filesystem syscalls to read - events. -- Has multiple `fs.Watch`es, with at most one watch per target inode, per - inotify instance. -- Has an instance `id` which is globally unique. This is *not* the fd number - for this instance, since the fd can be duped. This `id` is not externally - visible. - -### [`fs.Watch`][watch] - -- An inotify watch, created/deleted by - inotify_add_watch(2)/inotify_rm_watch(2). -- Owned by an `fs.Inotify` instance, each watch keeps a pointer to the - `owner`. -- Associated with a single `fs.Inode`, which is the watch `target`. While the - watch is active, it indirectly pins `target` to memory. See the "Reference - Model" section for a detailed explanation. -- Filesystem operations on `target` generate `fs.Event`s. - -### [`fs.Event`][event] - -- A simple struct encapsulating all the fields for an inotify event. -- Generated by `fs.Watch`es and forwarded to the watches' `owner`s. -- Serialized to the user during read(2) syscalls on the associated - `fs.Inotify`'s fd. - -### [`fs.Dirent`][dirent] - -- Many inotify events are generated inside dirent methods. Events are - generated in the dirent methods rather than `fs.Inode` methods because some - events carry the name of the subject node, and node names are generally - unavailable in an `fs.Inode`. -- Dirents do not directly contain state for any watches. Instead, they forward - notifications to the underlying `fs.Inode`. - -### [`fs.Inode`][inode] - -- Interacts with inotify through `fs.Watch`es. -- Inodes contain a map of all active `fs.Watch`es on them. -- An `fs.Inotify` instance can have at most one `fs.Watch` per inode. - `fs.Watch`es on an inode are indexed by their `owner`'s `id`. -- All inotify logic is encapsulated in the [`Watches`][inode_watches] struct - in an inode. Logically, `Watches` is the set of inotify watches on the - inode. - -## Reference Model - -The sentry inotify implementation has a complex reference model. An inotify -watch observes a single inode. For efficient lookup, the state for a watch is -stored directly on the target inode. This state needs to be persistent for the -lifetime of watch. Unlike usual filesystem metadata, the watch state has no -"on-disk" representation, so they cannot be reconstructed by the filesystem if -the inode is flushed from memory. This effectively means we need to keep any -inodes with actives watches pinned to memory. - -We can't just hold an extra ref on the inode to pin it to memory because some -filesystems (such as gofer-based filesystems) don't have persistent inodes. In -such a filesystem, if we just pin the inode, nothing prevents the enclosing -dirent from being GCed. Once the dirent is GCed, the pinned inode is -unreachable -- these filesystems generate a new inode by re-reading the node -state on the next walk. Incidentally, hardlinks also don't work on these -filesystems for this reason. - -To prevent the above scenario, when a new watch is added on an inode, we *pin* -the dirent we used to reach the inode. Note that due to hardlinks, this dirent -may not be the only dirent pointing to the inode. Attempting to set an inotify -watch via multiple hardlinks to the same file results in the same watch being -returned for both links. However, for each new dirent we use to reach the same -inode, we add a new pin. We need a new pin for each new dirent used to reach the -inode because we have no guarantees about the deletion order of the different -links to the inode. - -## Lock Ordering - -There are 4 locks related to the inotify implementation: - -- `Inotify.mu`: the inotify instance lock. -- `Inotify.evMu`: the inotify event queue lock. -- `Watch.mu`: the watch lock, used to protect pins. -- `fs.Watches.mu`: the inode watch set mu, used to protect the collection of - watches on the inode. - -The correct lock ordering for inotify code is: - -`Inotify.mu` -> `fs.Watches.mu` -> `Watch.mu` -> `Inotify.evMu`. - -We need a distinct lock for the event queue because by the time a goroutine -attempts to queue a new event, it is already holding `fs.Watches.mu`. If we used -`Inotify.mu` to also protect the event queue, this would violate the above lock -ordering. - -[dirent]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/dirent.go -[event]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify_event.go -[fd_map]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/kernel/fd_map.go -[inode]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inode.go -[inode_watches]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inode_inotify.go -[inotify]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify.go -[syscall_dir]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/syscalls/linux/ -[watch]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify_watch.go diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD deleted file mode 100644 index 6b993928c..000000000 --- a/pkg/sentry/fs/gofer/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "gofer", - srcs = [ - "attr.go", - "cache_policy.go", - "context_file.go", - "device.go", - "file.go", - "file_state.go", - "fs.go", - "handles.go", - "inode.go", - "inode_state.go", - "path.go", - "session.go", - "session_state.go", - "socket.go", - "util.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/gofer", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/fd", - "//pkg/log", - "//pkg/metric", - "//pkg/p9", - "//pkg/refs", - "//pkg/secio", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fdpipe", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/host", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/safemem", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/unet", - "//pkg/waiter", - ], -) - -go_test( - name = "gofer_test", - size = "small", - srcs = ["gofer_test.go"], - embed = [":gofer"], - deps = [ - "//pkg/p9", - "//pkg/p9/p9test", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - ], -) diff --git a/pkg/sentry/fs/gofer/gofer_state_autogen.go b/pkg/sentry/fs/gofer/gofer_state_autogen.go new file mode 100755 index 000000000..e05895fab --- /dev/null +++ b/pkg/sentry/fs/gofer/gofer_state_autogen.go @@ -0,0 +1,113 @@ +// automatically generated by stateify. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *fileOperations) beforeSave() {} +func (x *fileOperations) save(m state.Map) { + x.beforeSave() + m.Save("inodeOperations", &x.inodeOperations) + m.Save("dirCursor", &x.dirCursor) + m.Save("flags", &x.flags) +} + +func (x *fileOperations) load(m state.Map) { + m.LoadWait("inodeOperations", &x.inodeOperations) + m.Load("dirCursor", &x.dirCursor) + m.LoadWait("flags", &x.flags) + m.AfterLoad(x.afterLoad) +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *inodeOperations) beforeSave() {} +func (x *inodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("fileState", &x.fileState) + m.Save("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeOperations) afterLoad() {} +func (x *inodeOperations) load(m state.Map) { + m.LoadWait("fileState", &x.fileState) + m.Load("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeFileState) save(m state.Map) { + x.beforeSave() + var loading struct{} = x.saveLoading() + m.SaveValue("loading", loading) + m.Save("s", &x.s) + m.Save("sattr", &x.sattr) + m.Save("savedUAttr", &x.savedUAttr) + m.Save("hostMappable", &x.hostMappable) +} + +func (x *inodeFileState) load(m state.Map) { + m.LoadWait("s", &x.s) + m.LoadWait("sattr", &x.sattr) + m.Load("savedUAttr", &x.savedUAttr) + m.Load("hostMappable", &x.hostMappable) + m.LoadValue("loading", new(struct{}), func(y interface{}) { x.loadLoading(y.(struct{})) }) + m.AfterLoad(x.afterLoad) +} + +func (x *endpointMaps) beforeSave() {} +func (x *endpointMaps) save(m state.Map) { + x.beforeSave() + m.Save("direntMap", &x.direntMap) + m.Save("pathMap", &x.pathMap) +} + +func (x *endpointMaps) afterLoad() {} +func (x *endpointMaps) load(m state.Map) { + m.Load("direntMap", &x.direntMap) + m.Load("pathMap", &x.pathMap) +} + +func (x *session) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("msize", &x.msize) + m.Save("version", &x.version) + m.Save("cachePolicy", &x.cachePolicy) + m.Save("aname", &x.aname) + m.Save("superBlockFlags", &x.superBlockFlags) + m.Save("connID", &x.connID) + m.Save("inodeMappings", &x.inodeMappings) + m.Save("mounter", &x.mounter) + m.Save("endpoints", &x.endpoints) +} + +func (x *session) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.LoadWait("msize", &x.msize) + m.LoadWait("version", &x.version) + m.LoadWait("cachePolicy", &x.cachePolicy) + m.LoadWait("aname", &x.aname) + m.LoadWait("superBlockFlags", &x.superBlockFlags) + m.LoadWait("connID", &x.connID) + m.LoadWait("inodeMappings", &x.inodeMappings) + m.LoadWait("mounter", &x.mounter) + m.LoadWait("endpoints", &x.endpoints) + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("gofer.fileOperations", (*fileOperations)(nil), state.Fns{Save: (*fileOperations).save, Load: (*fileOperations).load}) + state.Register("gofer.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("gofer.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load}) + state.Register("gofer.inodeFileState", (*inodeFileState)(nil), state.Fns{Save: (*inodeFileState).save, Load: (*inodeFileState).load}) + state.Register("gofer.endpointMaps", (*endpointMaps)(nil), state.Fns{Save: (*endpointMaps).save, Load: (*endpointMaps).load}) + state.Register("gofer.session", (*session)(nil), state.Fns{Save: (*session).save, Load: (*session).load}) +} diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go deleted file mode 100644 index f671f219d..000000000 --- a/pkg/sentry/fs/gofer/gofer_test.go +++ /dev/null @@ -1,310 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gofer - -import ( - "fmt" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/p9" - "gvisor.dev/gvisor/pkg/p9/p9test" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// rootTest runs a test with a p9 mock and an fs.InodeOperations created from -// the attached root directory. The root file will be closed and client -// disconnected, but additional files must be closed manually. -func rootTest(t *testing.T, name string, cp cachePolicy, fn func(context.Context, *p9test.Harness, *p9test.Mock, *fs.Inode)) { - t.Run(name, func(t *testing.T) { - h, c := p9test.NewHarness(t) - defer h.Finish() - - // Create a new root. Note that we pass an empty, but non-nil - // map here. This allows tests to extend the root children - // dynamically. - root := h.NewDirectory(map[string]p9test.Generator{})(nil) - - // Return this as the root. - h.Attacher.EXPECT().Attach().Return(root, nil).Times(1) - - // ... and open via the client. - rootFile, err := c.Attach("/") - if err != nil { - t.Fatalf("unable to attach: %v", err) - } - defer rootFile.Close() - - // Wrap an a session. - s := &session{ - mounter: fs.RootOwner, - cachePolicy: cp, - client: c, - } - - // ... and an INode, with only the mode being explicitly valid for now. - ctx := contexttest.Context(t) - sattr, rootInodeOperations := newInodeOperations(ctx, s, contextFile{ - file: rootFile, - }, root.QID, p9.AttrMaskAll(), root.Attr, false /* socket */) - m := fs.NewMountSource(s, &filesystem{}, fs.MountSourceFlags{}) - rootInode := fs.NewInode(rootInodeOperations, m, sattr) - - // Ensure that the cache is fully invalidated, so that any - // close actions actually take place before the full harness is - // torn down. - defer func() { - m.FlushDirentRefs() - - // Wait for all resources to be released, otherwise the - // operations may fail after we close the rootFile. - fs.AsyncBarrier() - }() - - // Execute the test. - fn(ctx, h, root, rootInode) - }) -} - -func TestLookup(t *testing.T) { - type lookupTest struct { - // Name of the test. - name string - - // Expected return value. - want error - } - - tests := []lookupTest{ - { - name: "mock Walk passes (function succeeds)", - want: nil, - }, - { - name: "mock Walk fails (function fails)", - want: syscall.ENOENT, - }, - } - - const file = "file" // The walked target file. - - for _, test := range tests { - rootTest(t, test.name, cacheNone, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) { - // Setup the appropriate result. - rootFile.WalkCallback = func() error { - return test.want - } - if test.want == nil { - // Set the contents of the root. We expect a - // normal file generator for ppp above. This is - // overriden by setting WalkErr in the mock. - rootFile.AddChild(file, h.NewFile()) - } - - // Call function. - dirent, err := rootInode.Lookup(ctx, file) - - // Unwrap the InodeOperations. - var newInodeOperations fs.InodeOperations - if dirent != nil { - if dirent.IsNegative() { - err = syscall.ENOENT - } else { - newInodeOperations = dirent.Inode.InodeOperations - } - } - - // Check return values. - if err != test.want { - t.Errorf("Lookup got err %v, want %v", err, test.want) - } - if err == nil && newInodeOperations == nil { - t.Errorf("Lookup got non-nil err and non-nil node, wanted at least one non-nil") - } - }) - } -} - -func TestRevalidation(t *testing.T) { - type revalidationTest struct { - cachePolicy cachePolicy - - // Whether dirent should be reloaded before any modifications. - preModificationWantReload bool - - // Whether dirent should be reloaded after updating an unstable - // attribute on the remote fs. - postModificationWantReload bool - - // Whether dirent unstable attributes should be updated after - // updating an attribute on the remote fs. - postModificationWantUpdatedAttrs bool - - // Whether dirent should be reloaded after the remote has - // removed the file. - postRemovalWantReload bool - } - - tests := []revalidationTest{ - { - // Policy cacheNone causes Revalidate to always return - // true. - cachePolicy: cacheNone, - preModificationWantReload: true, - postModificationWantReload: true, - postModificationWantUpdatedAttrs: true, - postRemovalWantReload: true, - }, - { - // Policy cacheAll causes Revalidate to always return - // false. - cachePolicy: cacheAll, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: false, - postRemovalWantReload: false, - }, - { - // Policy cacheAllWritethrough causes Revalidate to - // always return false. - cachePolicy: cacheAllWritethrough, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: false, - postRemovalWantReload: false, - }, - { - // Policy cacheRemoteRevalidating causes Revalidate to - // return update cached unstable attrs, and returns - // true only when the remote inode itself has been - // removed or replaced. - cachePolicy: cacheRemoteRevalidating, - preModificationWantReload: false, - postModificationWantReload: false, - postModificationWantUpdatedAttrs: true, - postRemovalWantReload: true, - }, - } - - const file = "file" // The file walked below. - - for _, test := range tests { - name := fmt.Sprintf("cachepolicy=%s", test.cachePolicy) - rootTest(t, name, test.cachePolicy, func(ctx context.Context, h *p9test.Harness, rootFile *p9test.Mock, rootInode *fs.Inode) { - // Wrap in a dirent object. - rootDir := fs.NewDirent(rootInode, "root") - - // Create a mock file a child of the root. We save when - // this is generated, so that when the time changed, we - // can update the original entry. - var origMocks []*p9test.Mock - rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock { - // Regular a regular file that has a consistent - // path number. This might be used by - // validation so we don't change it. - m := h.NewMock(parent, 0, p9.Attr{ - Mode: p9.ModeRegular, - }) - origMocks = append(origMocks, m) - return m - }) - - // Do the walk. - dirent, err := rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - - // We must release the dirent, of the test will fail - // with a reference leak. This is tracked by p9test. - defer dirent.DecRef() - - // Walk again. Depending on the cache policy, we may - // get a new dirent. - newDirent, err := rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if test.preModificationWantReload && dirent == newDirent { - t.Errorf("Lookup with cachePolicy=%s got old dirent %+v, wanted a new dirent", test.cachePolicy, dirent) - } - if !test.preModificationWantReload && dirent != newDirent { - t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent) - } - newDirent.DecRef() // See above. - - // Modify the underlying mocked file's modification - // time for the next walk that occurs. - nowSeconds := time.Now().Unix() - rootFile.AddChild(file, func(parent *p9test.Mock) *p9test.Mock { - // Ensure that the path is the same as above, - // but we change only the modification time of - // the file. - return h.NewMock(parent, 0, p9.Attr{ - Mode: p9.ModeRegular, - MTimeSeconds: uint64(nowSeconds), - }) - }) - - // We also modify the original time, so that GetAttr - // behaves as expected for the caching case. - for _, m := range origMocks { - m.Attr.MTimeSeconds = uint64(nowSeconds) - } - - // Walk again. Depending on the cache policy, we may - // get a new dirent. - newDirent, err = rootDir.Walk(ctx, rootDir, file) - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if test.postModificationWantReload && dirent == newDirent { - t.Errorf("Lookup with cachePolicy=%s got old dirent, wanted a new dirent", test.cachePolicy) - } - if !test.postModificationWantReload && dirent != newDirent { - t.Errorf("Lookup with cachePolicy=%s got new dirent, wanted old dirent", test.cachePolicy) - } - uattrs, err := newDirent.Inode.UnstableAttr(ctx) - if err != nil { - t.Fatalf("Error getting unstable attrs: %v", err) - } - gotModTimeSeconds := uattrs.ModificationTime.Seconds() - if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds { - t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds) - } - newDirent.DecRef() // See above. - - // Remove the file from the remote fs, subsequent walks - // should now fail to find anything. - rootFile.RemoveChild(file) - - // Walk again. Depending on the cache policy, we may - // get ENOENT. - newDirent, err = rootDir.Walk(ctx, rootDir, file) - if test.postRemovalWantReload && err == nil { - t.Errorf("Lookup with cachePolicy=%s got nil error, wanted ENOENT", test.cachePolicy) - } - if !test.postRemovalWantReload && (err != nil || dirent != newDirent) { - t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err) - } - if err == nil { - newDirent.DecRef() // See above. - } - }) - } -} diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD deleted file mode 100644 index b1080fb1a..000000000 --- a/pkg/sentry/fs/host/BUILD +++ /dev/null @@ -1,83 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "host", - srcs = [ - "control.go", - "descriptor.go", - "descriptor_state.go", - "device.go", - "file.go", - "fs.go", - "inode.go", - "inode_state.go", - "ioctl_unsafe.go", - "socket.go", - "socket_iovec.go", - "socket_state.go", - "socket_unsafe.go", - "tty.go", - "util.go", - "util_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/refs", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/safemem", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/unimpl", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/unet", - "//pkg/waiter", - ], -) - -go_test( - name = "host_test", - size = "small", - srcs = [ - "descriptor_test.go", - "fs_test.go", - "inode_test.go", - "socket_test.go", - "wait_test.go", - ], - embed = [":host"], - deps = [ - "//pkg/fd", - "//pkg/fdnotifier", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go deleted file mode 100644 index 4205981f5..000000000 --- a/pkg/sentry/fs/host/descriptor_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "io/ioutil" - "path/filepath" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestDescriptorRelease(t *testing.T) { - for _, tc := range []struct { - name string - saveable bool - wouldBlock bool - }{ - {name: "all false"}, - {name: "saveable", saveable: true}, - {name: "wouldBlock", wouldBlock: true}, - } { - t.Run(tc.name, func(t *testing.T) { - dir, err := ioutil.TempDir("", "descriptor_test") - if err != nil { - t.Fatal("ioutil.TempDir() failed:", err) - } - - fd, err := syscall.Open(filepath.Join(dir, "file"), syscall.O_RDWR|syscall.O_CREAT, 0666) - if err != nil { - t.Fatal("failed to open temp file:", err) - } - - // FD ownership is transferred to the descritor. - queue := &waiter.Queue{} - d, err := newDescriptor(fd, false /* donated*/, tc.saveable, tc.wouldBlock, queue) - if err != nil { - syscall.Close(fd) - t.Fatalf("newDescriptor(%d, %t, false, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err) - } - if tc.saveable { - if d.origFD < 0 { - t.Errorf("saveable descriptor must preserve origFD, desc: %+v", d) - } - } - if tc.wouldBlock { - if !fdnotifier.HasFD(int32(d.value)) { - t.Errorf("FD not registered with notifier, desc: %+v", d) - } - } - - oldVal := d.value - d.Release() - if d.value != -1 { - t.Errorf("d.value want: -1, got: %d", d.value) - } - if tc.wouldBlock { - if fdnotifier.HasFD(int32(oldVal)) { - t.Errorf("FD not unregistered with notifier, desc: %+v", d) - } - } - }) - } -} diff --git a/pkg/sentry/fs/host/fs_test.go b/pkg/sentry/fs/host/fs_test.go deleted file mode 100644 index c6852ee30..000000000 --- a/pkg/sentry/fs/host/fs_test.go +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "fmt" - "io/ioutil" - "os" - "path" - "reflect" - "sort" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// newTestMountNamespace creates a MountNamespace with a ramfs root. -// It returns the host folder created, which should be removed when done. -func newTestMountNamespace(t *testing.T) (*fs.MountNamespace, string, error) { - p, err := ioutil.TempDir("", "root") - if err != nil { - return nil, "", err - } - - fd, err := open(nil, p) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - ctx := contexttest.Context(t) - root, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - mm, err := fs.NewMountNamespace(ctx, root) - if err != nil { - os.RemoveAll(p) - return nil, "", err - } - return mm, p, nil -} - -// createTestDirs populates the root with some test files and directories. -// /a/a1.txt -// /a/a2.txt -// /b/b1.txt -// /b/c/c1.txt -// /symlinks/normal.txt -// /symlinks/to_normal.txt -> /symlinks/normal.txt -// /symlinks/recursive -> /symlinks -func createTestDirs(ctx context.Context, t *testing.T, m *fs.MountNamespace) error { - r := m.Root() - defer r.DecRef() - - if err := r.CreateDirectory(ctx, r, "a", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - a, err := r.Walk(ctx, r, "a") - if err != nil { - return err - } - defer a.DecRef() - - a1, err := a.Create(ctx, r, "a1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - a1.DecRef() - - a2, err := a.Create(ctx, r, "a2.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - a2.DecRef() - - if err := r.CreateDirectory(ctx, r, "b", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - b, err := r.Walk(ctx, r, "b") - if err != nil { - return err - } - defer b.DecRef() - - b1, err := b.Create(ctx, r, "b1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - b1.DecRef() - - if err := b.CreateDirectory(ctx, r, "c", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - c, err := b.Walk(ctx, r, "c") - if err != nil { - return err - } - defer c.DecRef() - - c1, err := c.Create(ctx, r, "c1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - c1.DecRef() - - if err := r.CreateDirectory(ctx, r, "symlinks", fs.FilePermsFromMode(0777)); err != nil { - return err - } - - symlinks, err := r.Walk(ctx, r, "symlinks") - if err != nil { - return err - } - defer symlinks.DecRef() - - normal, err := symlinks.Create(ctx, r, "normal.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666)) - if err != nil { - return err - } - normal.DecRef() - - if err := symlinks.CreateLink(ctx, r, "/symlinks/normal.txt", "to_normal.txt"); err != nil { - return err - } - - return symlinks.CreateLink(ctx, r, "/symlinks", "recursive") -} - -// allPaths returns a slice of all paths of entries visible in the rootfs. -func allPaths(ctx context.Context, t *testing.T, m *fs.MountNamespace, base string) ([]string, error) { - var paths []string - root := m.Root() - defer root.DecRef() - - maxTraversals := uint(1) - d, err := m.FindLink(ctx, root, nil, base, &maxTraversals) - if err != nil { - t.Logf("FindLink failed for %q", base) - return paths, err - } - defer d.DecRef() - - if fs.IsDir(d.Inode.StableAttr) { - dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) - if err != nil { - return nil, fmt.Errorf("failed to open directory %q: %v", base, err) - } - iter, ok := dir.FileOperations.(fs.DirIterator) - if !ok { - return nil, fmt.Errorf("cannot directly iterate on host directory %q", base) - } - dirCtx := &fs.DirCtx{ - Serializer: noopDentrySerializer{}, - } - if _, err := fs.DirentReaddir(ctx, d, iter, root, dirCtx, 0); err != nil { - return nil, err - } - for name := range dirCtx.DentAttrs() { - if name == "." || name == ".." { - continue - } - - fullName := path.Join(base, name) - paths = append(paths, fullName) - - // Recurse. - subpaths, err := allPaths(ctx, t, m, fullName) - if err != nil { - return paths, err - } - paths = append(paths, subpaths...) - } - } - - return paths, nil -} - -type noopDentrySerializer struct{} - -func (noopDentrySerializer) CopyOut(string, fs.DentAttr) error { - return nil -} -func (noopDentrySerializer) Written() int { - return 4096 -} - -// pathsEqual returns true if the two string slices contain the same entries. -func pathsEqual(got, want []string) bool { - sort.Strings(got) - sort.Strings(want) - - if len(got) != len(want) { - return false - } - - for i := range got { - if got[i] != want[i] { - return false - } - } - - return true -} - -func TestWhitelist(t *testing.T) { - for _, test := range []struct { - // description of the test. - desc string - // paths are the paths to whitelist - paths []string - // want are all of the directory entries that should be - // visible (nothing beyond this set should be visible). - want []string - }{ - { - desc: "root", - paths: []string{"/"}, - want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt", "/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt", "/symlinks/recursive"}, - }, - { - desc: "top-level directories", - paths: []string{"/a", "/b"}, - want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "nested directories (1/2)", - paths: []string{"/b", "/b/c"}, - want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "nested directories (2/2)", - paths: []string{"/b/c", "/b"}, - want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "single file", - paths: []string{"/b/c/c1.txt"}, - want: []string{"/b", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "single file and directory", - paths: []string{"/a/a1.txt", "/b/c"}, - want: []string{"/a", "/a/a1.txt", "/b", "/b/c", "/b/c/c1.txt"}, - }, - { - desc: "symlink", - paths: []string{"/symlinks/to_normal.txt"}, - want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt"}, - }, - { - desc: "recursive symlink", - paths: []string{"/symlinks/recursive/normal.txt"}, - want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/recursive"}, - }, - } { - t.Run(test.desc, func(t *testing.T) { - m, p, err := newTestMountNamespace(t) - if err != nil { - t.Errorf("Failed to create MountNamespace: %v", err) - } - defer os.RemoveAll(p) - - ctx := withRoot(contexttest.RootContext(t), m.Root()) - if err := createTestDirs(ctx, t, m); err != nil { - t.Errorf("Failed to create test dirs: %v", err) - } - - if err := installWhitelist(ctx, m, test.paths); err != nil { - t.Errorf("installWhitelist(%v) err got %v want nil", test.paths, err) - } - - got, err := allPaths(ctx, t, m, "/") - if err != nil { - t.Fatalf("Failed to lookup paths (whitelisted: %v): %v", test.paths, err) - } - - if !pathsEqual(got, test.want) { - t.Errorf("For paths %v got %v want %v", test.paths, got, test.want) - } - }) - } -} - -func TestRootPath(t *testing.T) { - // Create a temp dir, which will be the root of our mounted fs. - rootPath, err := ioutil.TempDir(os.TempDir(), "root") - if err != nil { - t.Fatalf("TempDir failed: %v", err) - } - defer os.RemoveAll(rootPath) - - // Create two files inside the new root, one which will be whitelisted - // and one not. - whitelisted, err := ioutil.TempFile(rootPath, "white") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - if _, err := ioutil.TempFile(rootPath, "black"); err != nil { - t.Fatalf("TempFile failed: %v", err) - } - - // Create a mount with a root path and single whitelisted file. - hostFS := &Filesystem{} - ctx := contexttest.Context(t) - data := fmt.Sprintf("%s=%s,%s=%s", rootPathKey, rootPath, whitelistKey, whitelisted.Name()) - inode, err := hostFS.Mount(ctx, "", fs.MountSourceFlags{}, data, nil) - if err != nil { - t.Fatalf("Mount failed: %v", err) - } - mm, err := fs.NewMountNamespace(ctx, inode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - if err := hostFS.InstallWhitelist(ctx, mm); err != nil { - t.Fatalf("InstallWhitelist failed: %v", err) - } - - // Get the contents of the root directory. - rootDir := mm.Root() - rctx := withRoot(ctx, rootDir) - f, err := rootDir.Inode.GetFile(rctx, rootDir, fs.FileFlags{}) - if err != nil { - t.Fatalf("GetFile failed: %v", err) - } - c := &fs.CollectEntriesSerializer{} - if err := f.Readdir(rctx, c); err != nil { - t.Fatalf("Readdir failed: %v", err) - } - - // We should have only our whitelisted file, plus the dots. - want := []string{path.Base(whitelisted.Name()), ".", ".."} - got := c.Order - sort.Strings(want) - sort.Strings(got) - if !reflect.DeepEqual(got, want) { - t.Errorf("Readdir got %v, wanted %v", got, want) - } -} - -type rootContext struct { - context.Context - root *fs.Dirent -} - -// withRoot returns a copy of ctx with the given root. -func withRoot(ctx context.Context, root *fs.Dirent) context.Context { - return &rootContext{ - Context: ctx, - root: root, - } -} - -// Value implements Context.Value. -func (rc rootContext) Value(key interface{}) interface{} { - switch key { - case fs.CtxRoot: - rc.root.IncRef() - return rc.root - default: - return rc.Context.Value(key) - } -} diff --git a/pkg/sentry/fs/host/host_state_autogen.go b/pkg/sentry/fs/host/host_state_autogen.go new file mode 100755 index 000000000..9611da42a --- /dev/null +++ b/pkg/sentry/fs/host/host_state_autogen.go @@ -0,0 +1,138 @@ +// automatically generated by stateify. + +package host + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *descriptor) save(m state.Map) { + x.beforeSave() + m.Save("donated", &x.donated) + m.Save("origFD", &x.origFD) + m.Save("wouldBlock", &x.wouldBlock) +} + +func (x *descriptor) load(m state.Map) { + m.Load("donated", &x.donated) + m.Load("origFD", &x.origFD) + m.Load("wouldBlock", &x.wouldBlock) + m.AfterLoad(x.afterLoad) +} + +func (x *fileOperations) beforeSave() {} +func (x *fileOperations) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) + m.Save("dirCursor", &x.dirCursor) +} + +func (x *fileOperations) afterLoad() {} +func (x *fileOperations) load(m state.Map) { + m.LoadWait("iops", &x.iops) + m.Load("dirCursor", &x.dirCursor) +} + +func (x *Filesystem) beforeSave() {} +func (x *Filesystem) save(m state.Map) { + x.beforeSave() + m.Save("paths", &x.paths) +} + +func (x *Filesystem) afterLoad() {} +func (x *Filesystem) load(m state.Map) { + m.Load("paths", &x.paths) +} + +func (x *superOperations) beforeSave() {} +func (x *superOperations) save(m state.Map) { + x.beforeSave() + m.Save("SimpleMountSourceOperations", &x.SimpleMountSourceOperations) + m.Save("root", &x.root) + m.Save("inodeMappings", &x.inodeMappings) + m.Save("mounter", &x.mounter) + m.Save("dontTranslateOwnership", &x.dontTranslateOwnership) +} + +func (x *superOperations) afterLoad() {} +func (x *superOperations) load(m state.Map) { + m.Load("SimpleMountSourceOperations", &x.SimpleMountSourceOperations) + m.Load("root", &x.root) + m.Load("inodeMappings", &x.inodeMappings) + m.Load("mounter", &x.mounter) + m.Load("dontTranslateOwnership", &x.dontTranslateOwnership) +} + +func (x *inodeOperations) beforeSave() {} +func (x *inodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("fileState", &x.fileState) + m.Save("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeOperations) afterLoad() {} +func (x *inodeOperations) load(m state.Map) { + m.LoadWait("fileState", &x.fileState) + m.Load("cachingInodeOps", &x.cachingInodeOps) +} + +func (x *inodeFileState) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.queue) { m.Failf("queue is %v, expected zero", x.queue) } + m.Save("mops", &x.mops) + m.Save("descriptor", &x.descriptor) + m.Save("sattr", &x.sattr) + m.Save("savedUAttr", &x.savedUAttr) +} + +func (x *inodeFileState) load(m state.Map) { + m.LoadWait("mops", &x.mops) + m.LoadWait("descriptor", &x.descriptor) + m.LoadWait("sattr", &x.sattr) + m.Load("savedUAttr", &x.savedUAttr) + m.AfterLoad(x.afterLoad) +} + +func (x *ConnectedEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("queue", &x.queue) + m.Save("path", &x.path) + m.Save("ref", &x.ref) + m.Save("srfd", &x.srfd) + m.Save("stype", &x.stype) +} + +func (x *ConnectedEndpoint) load(m state.Map) { + m.Load("queue", &x.queue) + m.Load("path", &x.path) + m.Load("ref", &x.ref) + m.LoadWait("srfd", &x.srfd) + m.Load("stype", &x.stype) + m.AfterLoad(x.afterLoad) +} + +func (x *TTYFileOperations) beforeSave() {} +func (x *TTYFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("fileOperations", &x.fileOperations) + m.Save("session", &x.session) + m.Save("fgProcessGroup", &x.fgProcessGroup) +} + +func (x *TTYFileOperations) afterLoad() {} +func (x *TTYFileOperations) load(m state.Map) { + m.Load("fileOperations", &x.fileOperations) + m.Load("session", &x.session) + m.Load("fgProcessGroup", &x.fgProcessGroup) +} + +func init() { + state.Register("host.descriptor", (*descriptor)(nil), state.Fns{Save: (*descriptor).save, Load: (*descriptor).load}) + state.Register("host.fileOperations", (*fileOperations)(nil), state.Fns{Save: (*fileOperations).save, Load: (*fileOperations).load}) + state.Register("host.Filesystem", (*Filesystem)(nil), state.Fns{Save: (*Filesystem).save, Load: (*Filesystem).load}) + state.Register("host.superOperations", (*superOperations)(nil), state.Fns{Save: (*superOperations).save, Load: (*superOperations).load}) + state.Register("host.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load}) + state.Register("host.inodeFileState", (*inodeFileState)(nil), state.Fns{Save: (*inodeFileState).save, Load: (*inodeFileState).load}) + state.Register("host.ConnectedEndpoint", (*ConnectedEndpoint)(nil), state.Fns{Save: (*ConnectedEndpoint).save, Load: (*ConnectedEndpoint).load}) + state.Register("host.TTYFileOperations", (*TTYFileOperations)(nil), state.Fns{Save: (*TTYFileOperations).save, Load: (*TTYFileOperations).load}) +} diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go deleted file mode 100644 index bd2d856ee..000000000 --- a/pkg/sentry/fs/host/inode_test.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "io/ioutil" - "os" - "path" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -// TestMultipleReaddir verifies that multiple Readdir calls return the same -// thing if they use different dir contexts. -func TestMultipleReaddir(t *testing.T) { - p, err := ioutil.TempDir("", "readdir") - if err != nil { - t.Fatalf("Failed to create test dir: %v", err) - } - defer os.RemoveAll(p) - - f, err := os.Create(path.Join(p, "a.txt")) - if err != nil { - t.Fatalf("Failed to create a.txt: %v", err) - } - f.Close() - - f, err = os.Create(path.Join(p, "b.txt")) - if err != nil { - t.Fatalf("Failed to create b.txt: %v", err) - } - f.Close() - - fd, err := open(nil, p) - if err != nil { - t.Fatalf("Failed to open %q: %v", p, err) - } - ctx := contexttest.Context(t) - n, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false) - if err != nil { - t.Fatalf("Failed to create inode: %v", err) - } - - dirent := fs.NewDirent(n, "readdir") - openFile, err := n.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("Failed to get file: %v", err) - } - defer openFile.DecRef() - - c1 := &fs.DirCtx{DirCursor: new(string)} - if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, c1, 0); err != nil { - t.Fatalf("First Readdir failed: %v", err) - } - - c2 := &fs.DirCtx{DirCursor: new(string)} - if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, c2, 0); err != nil { - t.Errorf("Second Readdir failed: %v", err) - } - - if _, ok := c1.DentAttrs()["a.txt"]; !ok { - t.Errorf("want a.txt in first Readdir, got %v", c1.DentAttrs()) - } - if _, ok := c1.DentAttrs()["b.txt"]; !ok { - t.Errorf("want b.txt in first Readdir, got %v", c1.DentAttrs()) - } - - if _, ok := c2.DentAttrs()["a.txt"]; !ok { - t.Errorf("want a.txt in second Readdir, got %v", c2.DentAttrs()) - } - if _, ok := c2.DentAttrs()["b.txt"]; !ok { - t.Errorf("want b.txt in second Readdir, got %v", c2.DentAttrs()) - } -} - -// TestCloseFD verifies fds will be closed. -func TestCloseFD(t *testing.T) { - var p [2]int - if err := syscall.Pipe(p[0:]); err != nil { - t.Fatalf("Failed to create pipe %v", err) - } - defer syscall.Close(p[0]) - defer syscall.Close(p[1]) - - // Use the write-end because we will detect if it's closed on the read end. - ctx := contexttest.Context(t) - file, err := NewFile(ctx, p[1], fs.RootOwner) - if err != nil { - t.Fatalf("Failed to create File: %v", err) - } - file.DecRef() - - s := make([]byte, 10) - if c, err := syscall.Read(p[0], s); c != 0 || err != nil { - t.Errorf("want 0, nil (EOF) from read end, got %v, %v", c, err) - } -} diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go deleted file mode 100644 index 68b38fd1c..000000000 --- a/pkg/sentry/fs/host/socket_test.go +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "reflect" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/fd" - "gvisor.dev/gvisor/pkg/fdnotifier" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - // Make sure that ConnectedEndpoint implements transport.ConnectedEndpoint. - _ = transport.ConnectedEndpoint(new(ConnectedEndpoint)) - - // Make sure that ConnectedEndpoint implements transport.Receiver. - _ = transport.Receiver(new(ConnectedEndpoint)) -) - -func getFl(fd int) (uint32, error) { - fl, _, err := syscall.RawSyscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0) - if err == 0 { - return uint32(fl), nil - } - return 0, err -} - -func TestSocketIsBlocking(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - - fl, err := getFl(pair[0]) - if err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) - } - if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { - t.Fatalf("Expected socket %v to be blocking", pair[0]) - } - if fl, err = getFl(pair[1]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) - } - if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { - t.Fatalf("Expected socket %v to be blocking", pair[1]) - } - sock, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) failed => %v", pair[0], err) - } - defer sock.DecRef() - // Test that the socket now is non-blocking. - if fl, err = getFl(pair[0]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err) - } - if fl&syscall.O_NONBLOCK != syscall.O_NONBLOCK { - t.Errorf("Expected socket %v to have become non-blocking", pair[0]) - } - if fl, err = getFl(pair[1]); err != nil { - t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[1], err) - } - if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK { - t.Errorf("Did not expect socket %v to become non-blocking", pair[1]) - } -} - -func TestSocketWritev(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - socket, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer socket.DecRef() - buf := []byte("hello world\n") - n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf)) - if err != nil { - t.Fatalf("socket writev failed: %v", err) - } - - if n != int64(len(buf)) { - t.Fatalf("socket writev wrote incorrect bytes: %d", n) - } -} - -func TestSocketWritevLen0(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - socket, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer socket.DecRef() - n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil)) - if err != nil { - t.Fatalf("socket writev failed: %v", err) - } - - if n != 0 { - t.Fatalf("socket writev wrote incorrect bytes: %d", n) - } -} - -func TestSocketSendMsgLen0(t *testing.T) { - // Using socketpair here because it's already connected. - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("host socket creation failed: %v", err) - } - sfile, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer sfile.DecRef() - - s := sfile.FileOperations.(socket.Socket) - n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{}) - if n != 0 { - t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n) - } - - if terr != nil { - t.Fatalf("socket sendmsg() failed: %v", terr) - } -} - -func TestListen(t *testing.T) { - pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) - } - sfile1, err := newSocket(contexttest.Context(t), pair[0], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[0], err) - } - defer sfile1.DecRef() - socket1 := sfile1.FileOperations.(socket.Socket) - - sfile2, err := newSocket(contexttest.Context(t), pair[1], false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", pair[1], err) - } - defer sfile2.DecRef() - socket2 := sfile2.FileOperations.(socket.Socket) - - // Socketpairs can not be listened to. - if err := socket1.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket1.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } - if err := socket2.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket2.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } - - // Create a Unix socket, do not bind it. - sock, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err) - } - sfile3, err := newSocket(contexttest.Context(t), sock, false) - if err != nil { - t.Fatalf("newSocket(%v) => %v", sock, err) - } - defer sfile3.DecRef() - socket3 := sfile3.FileOperations.(socket.Socket) - - // This socket is not bound so we can't listen on it. - if err := socket3.Listen(nil, 64); err != syserr.ErrInvalidEndpointState { - t.Fatalf("socket3.Listen(nil, 64) => %v, want syserr.ErrInvalidEndpointState", err) - } -} - -func TestPasscred(t *testing.T) { - e := ConnectedEndpoint{} - if got, want := e.Passcred(), false; got != want { - t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want) - } -} - -func TestGetLocalAddress(t *testing.T) { - e := ConnectedEndpoint{path: "foo"} - want := tcpip.FullAddress{Addr: tcpip.Address("foo")} - if got, err := e.GetLocalAddress(); err != nil || got != want { - t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil) - } -} - -func TestQueuedSize(t *testing.T) { - e := ConnectedEndpoint{} - tests := []struct { - name string - f func() int64 - }{ - {"SendQueuedSize", e.SendQueuedSize}, - {"RecvQueuedSize", e.RecvQueuedSize}, - } - - for _, test := range tests { - if got, want := test.f(), int64(-1); got != want { - t.Errorf("Got %#v.%s() = %d, want = %d", e, test.name, got, want) - } - } -} - -func TestRelease(t *testing.T) { - f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - want := &ConnectedEndpoint{queue: c.queue} - want.ref.DecRef() - fdnotifier.AddFD(int32(c.file.FD()), nil) - c.Release() - if !reflect.DeepEqual(c, want) { - t.Errorf("got = %#v, want = %#v", c, want) - } -} diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go deleted file mode 100644 index 88d24d693..000000000 --- a/pkg/sentry/fs/host/wait_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package host - -import ( - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestWait(t *testing.T) { - var fds [2]int - err := syscall.Pipe(fds[:]) - if err != nil { - t.Fatalf("Unable to create pipe: %v", err) - } - - defer syscall.Close(fds[1]) - - ctx := contexttest.Context(t) - file, err := NewFile(ctx, fds[0], fs.RootOwner) - if err != nil { - syscall.Close(fds[0]) - t.Fatalf("NewFile failed: %v", err) - } - - defer file.DecRef() - - r := file.Readiness(waiter.EventIn) - if r != 0 { - t.Fatalf("File is ready for read when it shouldn't be.") - } - - e, ch := waiter.NewChannelEntry(nil) - file.EventRegister(&e, waiter.EventIn) - defer file.EventUnregister(&e) - - // Check that there are no notifications yet. - if len(ch) != 0 { - t.Fatalf("Channel is non-empty") - } - - // Write to the pipe, so it should be writable now. - syscall.Write(fds[1], []byte{1}) - - // Check that we get a notification. We need to yield the current thread - // so that the fdnotifier can deliver notifications, so we use a - // 1-second timeout instead of just checking the length of the channel. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Channel not notified") - } -} diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go deleted file mode 100644 index 7e56d5969..000000000 --- a/pkg/sentry/fs/inode_overlay_test.go +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/syserror" -) - -func TestLookup(t *testing.T) { - ctx := contexttest.Context(t) - for _, test := range []struct { - // Test description. - desc string - - // Lookup parameters. - dir *fs.Inode - name string - - // Want from lookup. - found bool - hasUpper bool - hasLower bool - }{ - { - desc: "no upper, lower has name", - dir: fs.NewTestOverlayDir(ctx, - nil, /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - { - desc: "no lower, upper has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - nil, /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, only lower has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "b", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - { - desc: "upper and lower, only upper has name", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "b", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, both have file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: false, - }, - { - desc: "upper and lower, both have directory", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: true, - }, - }, nil), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: true, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: true, - hasLower: true, - }, - { - desc: "upper and lower, upper negative masks lower file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, nil, []string{"a"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: false, - hasUpper: false, - hasLower: false, - }, - { - desc: "upper and lower, upper negative does not mask lower file", - dir: fs.NewTestOverlayDir(ctx, - newTestRamfsDir(ctx, nil, []string{"b"}), /* upper */ - newTestRamfsDir(ctx, []dirContent{ - { - name: "a", - dir: false, - }, - }, nil), /* lower */ - false /* revalidate */), - name: "a", - found: true, - hasUpper: false, - hasLower: true, - }, - } { - t.Run(test.desc, func(t *testing.T) { - dirent, err := test.dir.Lookup(ctx, test.name) - if test.found && (err == syserror.ENOENT || dirent.IsNegative()) { - t.Fatalf("lookup %q expected to find positive dirent, got dirent %v err %v", test.name, dirent, err) - } - if !test.found { - if err != syserror.ENOENT && !dirent.IsNegative() { - t.Errorf("lookup %q expected to return ENOENT or negative dirent, got dirent %v err %v", test.name, dirent, err) - } - // Nothing more to check. - return - } - if hasUpper := dirent.Inode.TestHasUpperFS(); hasUpper != test.hasUpper { - t.Fatalf("lookup got upper filesystem %v, want %v", hasUpper, test.hasUpper) - } - if hasLower := dirent.Inode.TestHasLowerFS(); hasLower != test.hasLower { - t.Errorf("lookup got lower filesystem %v, want %v", hasLower, test.hasLower) - } - }) - } -} - -func TestLookupRevalidation(t *testing.T) { - // File name used in the tests. - fileName := "foofile" - ctx := contexttest.Context(t) - for _, tc := range []struct { - // Test description. - desc string - - // Upper and lower fs for the overlay. - upper *fs.Inode - lower *fs.Inode - - // Whether the upper requires revalidation. - revalidate bool - - // Whether we should get the same dirent on second lookup. - wantSame bool - }{ - { - desc: "file from upper with no revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, nil, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from upper with revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, nil, nil), - revalidate: true, - wantSame: false, - }, - { - desc: "file from lower with no revalidation", - upper: newTestRamfsDir(ctx, nil, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from lower with revalidation", - upper: newTestRamfsDir(ctx, nil, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: true, - // The file does not exist in the upper, so we do not - // need to revalidate it. - wantSame: true, - }, - { - desc: "file from upper and lower with no revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: false, - wantSame: true, - }, - { - desc: "file from upper and lower with revalidation", - upper: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - lower: newTestRamfsDir(ctx, []dirContent{{name: fileName}}, nil), - revalidate: true, - wantSame: false, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - root := fs.NewDirent(newTestRamfsDir(ctx, nil, nil), "root") - ctx = &rootContext{ - Context: ctx, - root: root, - } - overlay := fs.NewDirent(fs.NewTestOverlayDir(ctx, tc.upper, tc.lower, tc.revalidate), "overlay") - // Lookup the file twice through the overlay. - first, err := overlay.Walk(ctx, root, fileName) - if err != nil { - t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err) - } - second, err := overlay.Walk(ctx, root, fileName) - if err != nil { - t.Fatalf("overlay.Walk(%q) failed: %v", fileName, err) - } - - if tc.wantSame && first != second { - t.Errorf("dirent lookup got different dirents, wanted same\nfirst=%+v\nsecond=%+v", first, second) - } else if !tc.wantSame && first == second { - t.Errorf("dirent lookup got the same dirent, wanted different: %+v", first) - } - }) - } -} - -func TestCacheFlush(t *testing.T) { - ctx := contexttest.Context(t) - - // Upper and lower each have a file. - upperFileName := "file-from-upper" - lowerFileName := "file-from-lower" - upper := newTestRamfsDir(ctx, []dirContent{{name: upperFileName}}, nil) - lower := newTestRamfsDir(ctx, []dirContent{{name: lowerFileName}}, nil) - - overlay := fs.NewTestOverlayDir(ctx, upper, lower, true /* revalidate */) - - mns, err := fs.NewMountNamespace(ctx, overlay) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - root := mns.Root() - defer root.DecRef() - - ctx = &rootContext{ - Context: ctx, - root: root, - } - - for _, fileName := range []string{upperFileName, lowerFileName} { - // Walk to the file. - maxTraversals := uint(0) - dirent, err := mns.FindInode(ctx, root, nil, fileName, &maxTraversals) - if err != nil { - t.Fatalf("FindInode(%q) failed: %v", fileName, err) - } - - // Get a file from the dirent. - file, err := dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - t.Fatalf("GetFile() failed: %v", err) - } - - // The dirent should have 3 refs, one from us, one from the - // file, and one from the dirent cache. - // dirent cache. - if got, want := dirent.ReadRefs(), 3; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Drop the file reference. - file.DecRef() - - // Dirent should have 2 refs left. - if got, want := dirent.ReadRefs(), 2; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Flush the dirent cache. - mns.FlushMountSourceRefs() - - // Dirent should have 1 ref left from the dirent cache. - if got, want := dirent.ReadRefs(), 1; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - - // Drop our ref. - dirent.DecRef() - - // We should be back to zero refs. - if got, want := dirent.ReadRefs(), 0; int(got) != want { - t.Errorf("dirent.ReadRefs() got %d want %d", got, want) - } - } - -} - -type dir struct { - fs.InodeOperations - - // List of negative child names. - negative []string - - // ReaddirCalled records whether Readdir was called on a file - // corresponding to this inode. - ReaddirCalled bool -} - -// Getxattr implements InodeOperations.Getxattr. -func (d *dir) Getxattr(inode *fs.Inode, name string) (string, error) { - for _, n := range d.negative { - if name == fs.XattrOverlayWhiteout(n) { - return "y", nil - } - } - return "", syserror.ENOATTR -} - -// GetFile implements InodeOperations.GetFile. -func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - file, err := d.InodeOperations.GetFile(ctx, dirent, flags) - if err != nil { - return nil, err - } - defer file.DecRef() - // Wrap the file's FileOperations in a dirFile. - fops := &dirFile{ - FileOperations: file.FileOperations, - inode: d, - } - return fs.NewFile(ctx, dirent, flags, fops), nil -} - -type dirContent struct { - name string - dir bool -} - -type dirFile struct { - fs.FileOperations - inode *dir -} - -type inode struct { - fsutil.InodeGenericChecker `state:"nosave"` - fsutil.InodeNoExtendedAttributes `state:"nosave"` - fsutil.InodeNoopRelease `state:"nosave"` - fsutil.InodeNoopWriteOut `state:"nosave"` - fsutil.InodeNotAllocatable `state:"nosave"` - fsutil.InodeNotDirectory `state:"nosave"` - fsutil.InodeNotMappable `state:"nosave"` - fsutil.InodeNotSocket `state:"nosave"` - fsutil.InodeNotSymlink `state:"nosave"` - fsutil.InodeNotTruncatable `state:"nosave"` - fsutil.InodeNotVirtual `state:"nosave"` - - fsutil.InodeSimpleAttributes - fsutil.InodeStaticFileGetter -} - -// Readdir implements fs.FileOperations.Readdir. It sets the ReaddirCalled -// field on the inode. -func (f *dirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySerializer) (int64, error) { - f.inode.ReaddirCalled = true - return f.FileOperations.Readdir(ctx, file, ser) -} - -func newTestRamfsInode(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - inode := fs.NewInode(&inode{ - InodeStaticFileGetter: fsutil.InodeStaticFileGetter{ - Contents: []byte("foobar"), - }, - }, msrc, fs.StableAttr{Type: fs.RegularFile}) - return inode -} - -func newTestRamfsDir(ctx context.Context, contains []dirContent, negative []string) *fs.Inode { - msrc := fs.NewPseudoMountSource() - contents := make(map[string]*fs.Inode) - for _, c := range contains { - if c.dir { - contents[c.name] = newTestRamfsDir(ctx, nil, nil) - } else { - contents[c.name] = newTestRamfsInode(ctx, msrc) - } - } - dops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermissions{ - User: fs.PermMask{Read: true, Execute: true}, - }) - return fs.NewInode(&dir{ - InodeOperations: dops, - negative: negative, - }, msrc, fs.StableAttr{Type: fs.Directory}) -} diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD deleted file mode 100644 index 08d7c0c57..000000000 --- a/pkg/sentry/fs/lock/BUILD +++ /dev/null @@ -1,58 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "lock_range", - out = "lock_range.go", - package = "lock", - prefix = "Lock", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "lock_set", - out = "lock_set.go", - consts = { - "minDegree": "3", - }, - package = "lock", - prefix = "Lock", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "LockRange", - "Value": "Lock", - "Functions": "lockSetFunctions", - }, -) - -go_library( - name = "lock", - srcs = [ - "lock.go", - "lock_range.go", - "lock_set.go", - "lock_set_functions.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/lock", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/log", - "//pkg/waiter", - ], -) - -go_test( - name = "lock_test", - size = "small", - srcs = [ - "lock_range_test.go", - "lock_test.go", - ], - embed = [":lock"], -) diff --git a/pkg/sentry/fs/lock/lock_range.go b/pkg/sentry/fs/lock/lock_range.go new file mode 100755 index 000000000..7a6f77640 --- /dev/null +++ b/pkg/sentry/fs/lock/lock_range.go @@ -0,0 +1,62 @@ +package lock + +// A Range represents a contiguous range of T. +// +// +stateify savable +type LockRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r LockRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r LockRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r LockRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r LockRange) Overlaps(r2 LockRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r LockRange) IsSupersetOf(r2 LockRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r LockRange) Intersect(r2 LockRange) LockRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r LockRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/fs/lock/lock_range_test.go b/pkg/sentry/fs/lock/lock_range_test.go deleted file mode 100644 index 6221199d1..000000000 --- a/pkg/sentry/fs/lock/lock_range_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lock - -import ( - "syscall" - "testing" -) - -func TestComputeRange(t *testing.T) { - tests := []struct { - // Description of test. - name string - - // Requested start of the lock range. - start int64 - - // Requested length of the lock range, - // can be negative :( - length int64 - - // Pre-computed file offset based on whence. - // Will be added to start. - offset int64 - - // Expected error. - err error - - // If error is nil, the expected LockRange. - LockRange - }{ - { - name: "offset, start, and length all zero", - LockRange: LockRange{Start: 0, End: LockEOF}, - }, - { - name: "zero offset, zero start, positive length", - start: 0, - length: 4096, - offset: 0, - LockRange: LockRange{Start: 0, End: 4096}, - }, - { - name: "zero offset, negative start", - start: -4096, - offset: 0, - err: syscall.EINVAL, - }, - { - name: "large offset, negative start, positive length", - start: -2048, - length: 2048, - offset: 4096, - LockRange: LockRange{Start: 2048, End: 4096}, - }, - { - name: "large offset, negative start, zero length", - start: -2048, - length: 0, - offset: 4096, - LockRange: LockRange{Start: 2048, End: LockEOF}, - }, - { - name: "zero offset, zero start, negative length", - start: 0, - length: -4096, - offset: 0, - err: syscall.EINVAL, - }, - { - name: "large offset, zero start, negative length", - start: 0, - length: -4096, - offset: 4096, - LockRange: LockRange{Start: 0, End: 4096}, - }, - { - name: "offset, start, and length equal, length is negative", - start: 1024, - length: -1024, - offset: 1024, - LockRange: LockRange{Start: 1024, End: 2048}, - }, - { - name: "offset, start, and length equal, start is negative", - start: -1024, - length: 1024, - offset: 1024, - LockRange: LockRange{Start: 0, End: 1024}, - }, - { - name: "offset, start, and length equal, offset is negative", - start: 1024, - length: 1024, - offset: -1024, - LockRange: LockRange{Start: 0, End: 1024}, - }, - { - name: "offset, start, and length equal, all negative", - start: -1024, - length: -1024, - offset: -1024, - err: syscall.EINVAL, - }, - { - name: "offset, start, and length equal, all positive", - start: 1024, - length: 1024, - offset: 1024, - LockRange: LockRange{Start: 2048, End: 3072}, - }, - } - - for _, test := range tests { - rng, err := ComputeRange(test.start, test.length, test.offset) - if err != test.err { - t.Errorf("%s: lockRange(%d, %d, %d) got error %v, want %v", test.name, test.start, test.length, test.offset, err, test.err) - continue - } - if err == nil && rng != test.LockRange { - t.Errorf("%s: lockRange(%d, %d, %d) got LockRange %v, want %v", test.name, test.start, test.length, test.offset, rng, test.LockRange) - } - } -} diff --git a/pkg/sentry/fs/lock/lock_set.go b/pkg/sentry/fs/lock/lock_set.go new file mode 100755 index 000000000..127ca5012 --- /dev/null +++ b/pkg/sentry/fs/lock/lock_set.go @@ -0,0 +1,1270 @@ +package lock + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + LockminDegree = 3 + + LockmaxDegree = 2 * LockminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type LockSet struct { + root Locknode `state:".(*LockSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *LockSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *LockSet) IsEmptyRange(r LockRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *LockSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *LockSet) SpanRange(r LockRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *LockSet) FirstSegment() LockIterator { + if s.root.nrSegments == 0 { + return LockIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *LockSet) LastSegment() LockIterator { + if s.root.nrSegments == 0 { + return LockIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *LockSet) FirstGap() LockGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return LockGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *LockSet) LastGap() LockGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return LockGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *LockSet) Find(key uint64) (LockIterator, LockGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return LockIterator{n, i}, LockGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return LockIterator{}, LockGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *LockSet) FindSegment(key uint64) LockIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *LockSet) LowerBoundSegment(min uint64) LockIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *LockSet) UpperBoundSegment(max uint64) LockIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *LockSet) FindGap(key uint64) LockGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *LockSet) LowerBoundGap(min uint64) LockGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *LockSet) UpperBoundGap(max uint64) LockGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *LockSet) Add(r LockRange, val Lock) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *LockSet) AddWithoutMerging(r LockRange, val Lock) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *LockSet) Insert(gap LockGapIterator, r LockRange, val Lock) LockIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (lockSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (lockSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *LockSet) InsertWithoutMerging(gap LockGapIterator, r LockRange, val Lock) LockIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *LockSet) InsertWithoutMergingUnchecked(gap LockGapIterator, r LockRange, val Lock) LockIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return LockIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *LockSet) Remove(seg LockIterator) LockGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + lockSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(LockGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *LockSet) RemoveAll() { + s.root = Locknode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *LockSet) RemoveRange(r LockRange) LockGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *LockSet) Merge(first, second LockIterator) LockIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *LockSet) MergeUnchecked(first, second LockIterator) LockIterator { + if first.End() == second.Start() { + if mval, ok := (lockSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return LockIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *LockSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *LockSet) MergeRange(r LockRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *LockSet) MergeAdjacent(r LockRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *LockSet) Split(seg LockIterator, split uint64) (LockIterator, LockIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *LockSet) SplitUnchecked(seg LockIterator, split uint64) (LockIterator, LockIterator) { + val1, val2 := (lockSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), LockRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *LockSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *LockSet) Isolate(seg LockIterator, r LockRange) LockIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *LockSet) ApplyContiguous(r LockRange, fn func(seg LockIterator)) LockGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return LockGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return LockGapIterator{} + } + } +} + +// +stateify savable +type Locknode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Locknode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [LockmaxDegree - 1]LockRange + values [LockmaxDegree - 1]Lock + children [LockmaxDegree]*Locknode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Locknode) firstSegment() LockIterator { + for n.hasChildren { + n = n.children[0] + } + return LockIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Locknode) lastSegment() LockIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return LockIterator{n, n.nrSegments - 1} +} + +func (n *Locknode) prevSibling() *Locknode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Locknode) nextSibling() *Locknode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Locknode) rebalanceBeforeInsert(gap LockGapIterator) LockGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < LockmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:LockminDegree-1], n.keys[:LockminDegree-1]) + copy(left.values[:LockminDegree-1], n.values[:LockminDegree-1]) + copy(right.keys[:LockminDegree-1], n.keys[LockminDegree:]) + copy(right.values[:LockminDegree-1], n.values[LockminDegree:]) + n.keys[0], n.values[0] = n.keys[LockminDegree-1], n.values[LockminDegree-1] + LockzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:LockminDegree], n.children[:LockminDegree]) + copy(right.children[:LockminDegree], n.children[LockminDegree:]) + LockzeroNodeSlice(n.children[2:]) + for i := 0; i < LockminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < LockminDegree { + return LockGapIterator{left, gap.index} + } + return LockGapIterator{right, gap.index - LockminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[LockminDegree-1], n.values[LockminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Locknode{ + nrSegments: LockminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:LockminDegree-1], n.keys[LockminDegree:]) + copy(sibling.values[:LockminDegree-1], n.values[LockminDegree:]) + LockzeroValueSlice(n.values[LockminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:LockminDegree], n.children[LockminDegree:]) + LockzeroNodeSlice(n.children[LockminDegree:]) + for i := 0; i < LockminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = LockminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < LockminDegree { + return gap + } + return LockGapIterator{sibling, gap.index - LockminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Locknode) rebalanceAfterRemove(gap LockGapIterator) LockGapIterator { + for { + if n.nrSegments >= LockminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= LockminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return LockGapIterator{n, 0} + } + if gap.node == n { + return LockGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= LockminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + lockSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return LockGapIterator{n, n.nrSegments} + } + return LockGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return LockGapIterator{p, gap.index} + } + if gap.node == right { + return LockGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Locknode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = LockGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + lockSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type LockIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Locknode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg LockIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg LockIterator) Range() LockRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg LockIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg LockIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg LockIterator) SetRangeUnchecked(r LockRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetRange(r LockRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg LockIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg LockIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg LockIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg LockIterator) Value() Lock { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg LockIterator) ValuePtr() *Lock { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg LockIterator) SetValue(val Lock) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg LockIterator) PrevSegment() LockIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return LockIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return LockIterator{} + } + return LocksegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg LockIterator) NextSegment() LockIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return LockIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return LockIterator{} + } + return LocksegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg LockIterator) PrevGap() LockGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return LockGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg LockIterator) NextGap() LockGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return LockGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg LockIterator) PrevNonEmpty() (LockIterator, LockGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return LockIterator{}, gap + } + return gap.PrevSegment(), LockGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg LockIterator) NextNonEmpty() (LockIterator, LockGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return LockIterator{}, gap + } + return gap.NextSegment(), LockGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type LockGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Locknode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap LockGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap LockGapIterator) Range() LockRange { + return LockRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap LockGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return lockSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap LockGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return lockSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap LockGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap LockGapIterator) PrevSegment() LockIterator { + return LocksegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap LockGapIterator) NextSegment() LockIterator { + return LocksegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap LockGapIterator) PrevGap() LockGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return LockGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap LockGapIterator) NextGap() LockGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return LockGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func LocksegmentBeforePosition(n *Locknode, i int) LockIterator { + for i == 0 { + if n.parent == nil { + return LockIterator{} + } + n, i = n.parent, n.parentIndex + } + return LockIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func LocksegmentAfterPosition(n *Locknode, i int) LockIterator { + for i == n.nrSegments { + if n.parent == nil { + return LockIterator{} + } + n, i = n.parent, n.parentIndex + } + return LockIterator{n, i} +} + +func LockzeroValueSlice(slice []Lock) { + + for i := range slice { + lockSetFunctions{}.ClearValue(&slice[i]) + } +} + +func LockzeroNodeSlice(slice []*Locknode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *LockSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *Locknode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Locknode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type LockSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []Lock +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *LockSet) ExportSortedSlices() *LockSegmentDataSlices { + var sds LockSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *LockSet) ImportSortedSlices(sds *LockSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := LockRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *LockSet) saveRoot() *LockSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *LockSet) loadRoot(sds *LockSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/fs/lock/lock_state_autogen.go b/pkg/sentry/fs/lock/lock_state_autogen.go new file mode 100755 index 000000000..cb69e2cd0 --- /dev/null +++ b/pkg/sentry/fs/lock/lock_state_autogen.go @@ -0,0 +1,106 @@ +// automatically generated by stateify. + +package lock + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Lock) beforeSave() {} +func (x *Lock) save(m state.Map) { + x.beforeSave() + m.Save("Readers", &x.Readers) + m.Save("HasWriter", &x.HasWriter) + m.Save("Writer", &x.Writer) +} + +func (x *Lock) afterLoad() {} +func (x *Lock) load(m state.Map) { + m.Load("Readers", &x.Readers) + m.Load("HasWriter", &x.HasWriter) + m.Load("Writer", &x.Writer) +} + +func (x *Locks) beforeSave() {} +func (x *Locks) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.blockedQueue) { m.Failf("blockedQueue is %v, expected zero", x.blockedQueue) } + m.Save("locks", &x.locks) +} + +func (x *Locks) afterLoad() {} +func (x *Locks) load(m state.Map) { + m.Load("locks", &x.locks) +} + +func (x *LockRange) beforeSave() {} +func (x *LockRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *LockRange) afterLoad() {} +func (x *LockRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *LockSet) beforeSave() {} +func (x *LockSet) save(m state.Map) { + x.beforeSave() + var root *LockSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *LockSet) afterLoad() {} +func (x *LockSet) load(m state.Map) { + m.LoadValue("root", new(*LockSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*LockSegmentDataSlices)) }) +} + +func (x *Locknode) beforeSave() {} +func (x *Locknode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *Locknode) afterLoad() {} +func (x *Locknode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *LockSegmentDataSlices) beforeSave() {} +func (x *LockSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *LockSegmentDataSlices) afterLoad() {} +func (x *LockSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("lock.Lock", (*Lock)(nil), state.Fns{Save: (*Lock).save, Load: (*Lock).load}) + state.Register("lock.Locks", (*Locks)(nil), state.Fns{Save: (*Locks).save, Load: (*Locks).load}) + state.Register("lock.LockRange", (*LockRange)(nil), state.Fns{Save: (*LockRange).save, Load: (*LockRange).load}) + state.Register("lock.LockSet", (*LockSet)(nil), state.Fns{Save: (*LockSet).save, Load: (*LockSet).load}) + state.Register("lock.Locknode", (*Locknode)(nil), state.Fns{Save: (*Locknode).save, Load: (*Locknode).load}) + state.Register("lock.LockSegmentDataSlices", (*LockSegmentDataSlices)(nil), state.Fns{Save: (*LockSegmentDataSlices).save, Load: (*LockSegmentDataSlices).load}) +} diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go deleted file mode 100644 index ba002aeb7..000000000 --- a/pkg/sentry/fs/lock/lock_test.go +++ /dev/null @@ -1,1059 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lock - -import ( - "reflect" - "testing" -) - -type entry struct { - Lock - LockRange -} - -func equals(e0, e1 []entry) bool { - if len(e0) != len(e1) { - return false - } - for i := range e0 { - for k := range e0[i].Lock.Readers { - if !e1[i].Lock.Readers[k] { - return false - } - } - for k := range e1[i].Lock.Readers { - if !e0[i].Lock.Readers[k] { - return false - } - } - if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) { - return false - } - if e0[i].Lock.HasWriter != e1[i].Lock.HasWriter { - return false - } - if e0[i].Lock.Writer != e1[i].Lock.Writer { - return false - } - } - return true -} - -// fill a LockSet with consecutive region locks. Will panic if -// LockRanges are not consecutive. -func fill(entries []entry) LockSet { - l := LockSet{} - for _, e := range entries { - gap := l.FindGap(e.LockRange.Start) - if !gap.Ok() { - panic("cannot insert into existing segment") - } - l.Insert(gap, e.LockRange, e.Lock) - } - return l -} - -func TestCanLockEmpty(t *testing.T) { - l := LockSet{} - - // Expect to be able to take any locks given that the set is empty. - eof := l.FirstGap().End() - r := LockRange{0, eof} - if !l.canLock(1, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1) - } - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - if !l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2) - } -} - -func TestCanLock(t *testing.T) { - // + -------------- + ---------- + -------------- + --------- + - // | Readers 1 & 2 | Readers 1 | Readers 1 & 3 | Writer 1 | - // + ------------- + ---------- + -------------- + --------- + - // 0 1024 2048 3072 4096 - l := fill([]entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 3: true}}, - LockRange: LockRange{2048, 3072}, - }, - { - Lock: Lock{HasWriter: true, Writer: 1}, - LockRange: LockRange{3072, 4096}, - }, - }) - - // Now that we have a mildly interesting layout, try some checks on different - // ranges, uids, and lock types. - // - // Expect to be able to extend the read lock, despite the writer lock, because - // the writer has the same uid as the requested read lock. - r := LockRange{0, 8192} - if !l.canLock(1, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 1) - } - // Expect to *not* be able to extend the read lock since there is an overlapping - // writer region locked by someone other than the uid. - if l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", ReadLock, r, 2) - } - // Expect to be able to extend the read lock if there are only other readers in - // the way. - r = LockRange{64, 3072} - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - // Expect to be able to set a read lock beyond the range of any existing locks. - r = LockRange{4096, 10240} - if !l.canLock(2, ReadLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", ReadLock, r, 2) - } - - // Expect to not be able to take a write lock with other readers in the way. - r = LockRange{0, 8192} - if l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 1) - } - // Expect to be able to extend the write lock for the same uid. - r = LockRange{3072, 8192} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - // Expect to not be able to overlap a write lock for two different uids. - if l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got true, want false", WriteLock, r, 2) - } - // Expect to be able to set a write lock that is beyond the range of any - // existing locks. - r = LockRange{8192, 10240} - if !l.canLock(2, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 2) - } - // Expect to be able to upgrade a read lock (any portion of it). - r = LockRange{1024, 2048} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } - r = LockRange{1080, 2000} - if !l.canLock(1, WriteLock, r) { - t.Fatalf("canLock type %d for range %v and uid %d got false, want true", WriteLock, r, 1) - } -} - -func TestSetLock(t *testing.T) { - tests := []struct { - // description of test. - name string - - // LockSet entries to pre-fill. - before []entry - - // Description of region to lock: - // - // start is the file offset of the lock. - start uint64 - // end is the end file offset of the lock. - end uint64 - // uid of lock attempter. - uid UniqueID - // lock type requested. - lockType LockType - - // success is true if taking the above - // lock should succeed. - success bool - - // Expected layout of the set after locking - // if success is true. - after []entry - }{ - { - name: "set zero length ReadLock on empty set", - start: 0, - end: 0, - uid: 0, - lockType: ReadLock, - success: true, - }, - { - name: "set zero length WriteLock on empty set", - start: 0, - end: 0, - uid: 0, - lockType: WriteLock, - success: true, - }, - { - name: "set ReadLock on empty set", - start: 0, - end: LockEOF, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "set WriteLock on empty set", - start: 0, - end: LockEOF, - uid: 0, - lockType: WriteLock, - success: true, - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "set ReadLock on WriteLock same uid", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------- + --------------------------- + - // | Readers 0 | Writer 0 | - // + ----------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "set WriteLock on ReadLock same uid", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - lockType: WriteLock, - success: true, - // + ----------- + --------------------------- + - // | Writer 0 | Readers 0 | - // + ----------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "set ReadLock on WriteLock different uid", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: false, - }, - { - name: "set WriteLock on ReadLock different uid", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: WriteLock, - success: false, - }, - { - name: "split ReadLock for overlapping lock at start 0", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: true, - // + -------------- + --------------------------- + - // | Readers 0 & 1 | Readers 0 | - // + -------------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "split ReadLock for overlapping lock at non-zero start", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: 8192, - uid: 1, - lockType: ReadLock, - success: true, - // + ---------- + -------------- + ----------- + - // | Readers 0 | Readers 0 & 1 | Readers 0 | - // + ---------- + -------------- + ----------- + - // 0 4096 8192 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, 8192}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{8192, LockEOF}, - }, - }, - }, - { - name: "fill front gap with ReadLock", - // + --------- + ---------------------------- + - // | gap | Readers 0 | - // + --------- + ---------------------------- + - // 0 1024 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - start: 0, - end: 8192, - uid: 0, - lockType: ReadLock, - success: true, - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "fill end gap with ReadLock", - // + ---------------------------- + - // | Readers 0 | - // + ---------------------------- + - // 0 4096 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - }, - start: 1024, - end: LockEOF, - uid: 0, - lockType: ReadLock, - success: true, - // Note that this is not merged after lock does a Split. This is - // fine because the two locks will still *behave* as one. In other - // words we can fragment any lock all we want and semantically it - // makes no difference. - // - // + ----------- + --------------------------- + - // | Readers 0 | Readers 0 | - // + ----------- + --------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - }, - { - name: "fill gap with ReadLock and split", - // + --------- + ---------------------------- + - // | gap | Readers 0 | - // + --------- + ---------------------------- + - // 0 1024 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 1, - lockType: ReadLock, - success: true, - // + --------- + ------------- + ------------- + - // | Reader 1 | Readers 0 & 1 | Reader 0 | - // + ----------+ ------------- + ------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "upgrade ReadLock to WriteLock for single uid fill gap", - // + ------------- + --------- + --- + ------------- + - // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 | - // + ------------- + --------- + --- + ------------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - lockType: WriteLock, - success: true, - // + ------------- + -------- + ------------- + - // | Readers 0 & 1 | Writer 0 | Readers 0 & 2 | - // + ------------- + -------- + ------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "upgrade ReadLock to WriteLock for single uid keep gap", - // + ------------- + --------- + --- + ------------- + - // | Readers 0 & 1 | Readers 0 | gap | Readers 0 & 2 | - // + ------------- + --------- + --- + ------------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 1024, - end: 3072, - uid: 0, - lockType: WriteLock, - success: true, - // + ------------- + -------- + --- + ------------- + - // | Readers 0 & 1 | Writer 0 | gap | Readers 0 & 2 | - // + ------------- + -------- + --- + ------------- + - // 0 1024 3072 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, 3072}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "fail to upgrade ReadLock to WriteLock with conflicting Reader", - // + ------------- + --------- + - // | Readers 0 & 1 | Readers 0 | - // + ------------- + --------- + - // 0 1024 2048 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - }, - start: 0, - end: 2048, - uid: 0, - lockType: WriteLock, - success: false, - }, - { - name: "take WriteLock on whole file if all uids are the same", - // + ------------- + --------- + --------- + ---------- + - // | Writer 0 | Readers 0 | Readers 0 | Readers 0 | - // + ------------- + --------- + --------- + ---------- + - // 0 1024 2048 4096 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{1024, 2048}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{2048, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - lockType: WriteLock, - success: true, - // We do not manually merge locks. Semantically a fragmented lock - // held by the same uid will behave as one lock so it makes no difference. - // - // + ------------- + ---------------------------- + - // | Writer 0 | Writer 0 | - // + ------------- + ---------------------------- + - // 0 1024 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, LockEOF}, - }, - }, - }, - } - - for _, test := range tests { - l := fill(test.before) - - r := LockRange{Start: test.start, End: test.end} - success := l.lock(test.uid, test.lockType, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } - - if success != test.success { - t.Errorf("%s: setlock(%v, %+v, %d, %d) got success %v, want %v", test.name, test.before, r, test.uid, test.lockType, success, test.success) - continue - } - - if success { - if !equals(got, test.after) { - t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after) - } - } - } -} - -func TestUnlock(t *testing.T) { - tests := []struct { - // description of test. - name string - - // LockSet entries to pre-fill. - before []entry - - // Description of region to unlock: - // - // start is the file start of the lock. - start uint64 - // end is the end file start of the lock. - end uint64 - // uid of lock holder. - uid UniqueID - - // Expected layout of the set after unlocking. - after []entry - }{ - { - name: "unlock zero length on empty set", - start: 0, - end: 0, - uid: 0, - }, - { - name: "unlock on empty set (no-op)", - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock uid not locked (no-op)", - // + --------------------------- + - // | Readers 1 & 2 | - // + --------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - // + --------------------------- + - // | Readers 1 & 2 | - // + --------------------------- + - // 0 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - }, - { - name: "unlock ReadLock over entire file", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock WriteLock over entire file", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - }, - { - name: "unlock partial ReadLock (start)", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - // + ------ + --------------------------- + - // | gap | Readers 0 | - // +------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock partial WriteLock (start)", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 0, - end: 4096, - uid: 0, - // + ------ + --------------------------- + - // | gap | Writer 0 | - // +------- + --------------------------- + - // 0 4096 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock partial ReadLock (end)", - // + ----------------------------------------- + - // | Readers 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: LockEOF, - uid: 0, - // + --------------------------- + - // | Readers 0 | - // +---------------------------- + - // 0 4096 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, - LockRange: LockRange{0, 4096}, - }, - }, - }, - { - name: "unlock partial WriteLock (end)", - // + ----------------------------------------- + - // | Writer 0 | - // + ----------------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 4096, - end: LockEOF, - uid: 0, - // + --------------------------- + - // | Writer 0 | - // +---------------------------- + - // 0 4096 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 4096}, - }, - }, - }, - { - name: "unlock for single uid", - // + ------------- + --------- + ------------------- + - // | Readers 0 & 1 | Writer 0 | Readers 0 & 1 & 2 | - // + ------------- + --------- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 0, - end: LockEOF, - uid: 0, - // + --------- + --- + --------------- + - // | Readers 1 | gap | Readers 1 & 2 | - // + --------- + --- + --------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock subsection locked", - // + ------------------------------- + - // | Readers 0 & 1 & 2 | - // + ------------------------------- + - // 0 max uint64 - before: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{0, LockEOF}, - }, - }, - start: 1024, - end: 4096, - uid: 0, - // + ----------------- + ------------- + ----------------- + - // | Readers 0 & 1 & 2 | Readers 1 & 2 | Readers 0 & 1 & 2 | - // + ----------------- + ------------- + ----------------- + - // 0 1024 4096 max uint64 - after: []entry{ - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, - LockRange: LockRange{1024, 4096}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock mid-gap to increase gap", - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 8, - end: 2048, - uid: 0, - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 8 4096 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 8}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - }, - { - name: "unlock split region on uid mid-gap", - // + --------- + ----- + ------------------- + - // | Writer 0 | gap | Readers 0 & 1 | - // + --------- + ----- + ------------------- + - // 0 1024 4096 max uint64 - before: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{4096, LockEOF}, - }, - }, - start: 2048, - end: 8192, - uid: 0, - // + --------- + ----- + --------- + ------------- + - // | Writer 0 | gap | Readers 1 | Readers 0 & 1 | - // + --------- + ----- + --------- + ------------- + - // 0 1024 4096 8192 max uint64 - after: []entry{ - { - Lock: Lock{HasWriter: true, Writer: 0}, - LockRange: LockRange{0, 1024}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, - LockRange: LockRange{4096, 8192}, - }, - { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, - LockRange: LockRange{8192, LockEOF}, - }, - }, - }, - } - - for _, test := range tests { - l := fill(test.before) - - r := LockRange{Start: test.start, End: test.end} - l.unlock(test.uid, r) - var got []entry - for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - got = append(got, entry{ - Lock: seg.Value(), - LockRange: seg.Range(), - }) - } - if !equals(got, test.after) { - t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after) - } - } -} diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go deleted file mode 100644 index 0b84732aa..000000000 --- a/pkg/sentry/fs/mount_test.go +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" -) - -// cacheReallyContains iterates through the dirent cache to determine whether -// it contains the given dirent. -func cacheReallyContains(cache *DirentCache, d *Dirent) bool { - for i := cache.list.Front(); i != nil; i = i.Next() { - if i == d { - return true - } - } - return false -} - -func mountPathsAre(root *Dirent, got []*Mount, want ...string) error { - gotPaths := make(map[string]struct{}, len(got)) - gotStr := make([]string, len(got)) - for i, g := range got { - groot := g.Root() - name, _ := groot.FullName(root) - groot.DecRef() - gotStr[i] = name - gotPaths[name] = struct{}{} - } - if len(got) != len(want) { - return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want) - } - for _, w := range want { - if _, ok := gotPaths[w]; !ok { - return fmt.Errorf("no mount with path %q found", w) - } - } - return nil -} - -// TestMountSourceOnlyCachedOnce tests that a Dirent that is mounted over only ends -// up in a single Dirent Cache. NOTE(b/63848693): Having a dirent in multiple -// caches causes major consistency issues. -func TestMountSourceOnlyCachedOnce(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef() - - // Get a child of the root which we will mount over. Note that the - // MockInodeOperations causes Walk to always succeed. - child, err := rootDirent.Walk(ctx, rootDirent, "child") - if err != nil { - t.Fatalf("failed to walk to child dirent: %v", err) - } - child.maybeExtendReference() // Cache. - - // Ensure that the root cache contains the child. - if !cacheReallyContains(rootCache, child) { - t.Errorf("wanted rootCache to contain child dirent, but it did not") - } - - // Create a new cache and inode, and mount it over child. - submountCache := NewDirentCache(100) - submountInode := NewMockInode(ctx, NewMockMountSource(submountCache), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, child, submountInode); err != nil { - t.Fatalf("failed to mount over child: %v", err) - } - - // Walk to the child again. - child2, err := rootDirent.Walk(ctx, rootDirent, "child") - if err != nil { - t.Fatalf("failed to walk to child dirent: %v", err) - } - - // Should have a different Dirent than before. - if child == child2 { - t.Fatalf("expected %v not equal to %v, but they are the same", child, child2) - } - - // Neither of the caches should no contain the child. - if cacheReallyContains(rootCache, child) { - t.Errorf("wanted rootCache not to contain child dirent, but it did") - } - if cacheReallyContains(submountCache, child) { - t.Errorf("wanted submountCache not to contain child dirent, but it did") - } -} - -func TestAllMountsUnder(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef() - - // Add mounts at the following paths: - paths := []string{ - "/foo", - "/foo/bar", - "/foo/bar/baz", - "/foo/qux", - "/waldo", - } - - var maxTraversals uint - for _, p := range paths { - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, d, submountInode); err != nil { - t.Fatalf("could not mount at %q: %v", p, err) - } - d.DecRef() - } - - // mm root should contain all submounts (and does not include the root mount). - rootMnt := mm.FindMount(rootDirent) - submounts := mm.AllMountsUnder(rootMnt) - allPaths := append(paths, "/") - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { - t.Error(err) - } - - // Each mount should have a unique ID. - foundIDs := make(map[uint64]struct{}) - for _, m := range submounts { - if _, ok := foundIDs[m.ID]; ok { - t.Errorf("got multiple mounts with id %d", m.ID) - } - foundIDs[m.ID] = struct{}{} - } - - // Root mount should have no parent. - if p := rootMnt.ParentID; p != invalidMountID { - t.Errorf("root.Parent got %v wanted nil", p) - } - - // Check that "foo" mount has 3 children. - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, "/foo", &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", "/foo", err) - } - defer d.DecRef() - submounts = mm.AllMountsUnder(mm.FindMount(d)) - if err := mountPathsAre(rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil { - t.Error(err) - } - - // "waldo" mount should have no children. - maxTraversals = 0 - waldo, err := mm.FindLink(ctx, rootDirent, nil, "/waldo", &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err) - } - defer waldo.DecRef() - submounts = mm.AllMountsUnder(mm.FindMount(waldo)) - if err := mountPathsAre(rootDirent, submounts, "/waldo"); err != nil { - t.Error(err) - } -} - -func TestUnmount(t *testing.T) { - ctx := contexttest.Context(t) - - rootCache := NewDirentCache(100) - rootInode := NewMockInode(ctx, NewMockMountSource(rootCache), StableAttr{ - Type: Directory, - }) - mm, err := NewMountNamespace(ctx, rootInode) - if err != nil { - t.Fatalf("NewMountNamespace failed: %v", err) - } - rootDirent := mm.Root() - defer rootDirent.DecRef() - - // Add mounts at the following paths: - paths := []string{ - "/foo", - "/foo/bar", - "/foo/bar/goo", - "/foo/bar/goo/abc", - "/foo/abc", - "/foo/def", - "/waldo", - "/wally", - } - - var maxTraversals uint - for _, p := range paths { - maxTraversals = 0 - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - submountInode := NewMockInode(ctx, NewMockMountSource(nil), StableAttr{ - Type: Directory, - }) - if err := mm.Mount(ctx, d, submountInode); err != nil { - t.Fatalf("could not mount at %q: %v", p, err) - } - d.DecRef() - } - - allPaths := make([]string, len(paths)+1) - allPaths[0] = "/" - copy(allPaths[1:], paths) - - rootMnt := mm.FindMount(rootDirent) - for i := len(paths) - 1; i >= 0; i-- { - maxTraversals = 0 - p := paths[i] - d, err := mm.FindLink(ctx, rootDirent, nil, p, &maxTraversals) - if err != nil { - t.Fatalf("could not find path %q in mount manager: %v", p, err) - } - - if err := mm.Unmount(ctx, d, false); err != nil { - t.Fatalf("could not unmount at %q: %v", p, err) - } - d.DecRef() - - // Remove the path that has been unmounted and the check that the remaining - // mounts are still there. - allPaths = allPaths[:len(allPaths)-1] - submounts := mm.AllMountsUnder(rootMnt) - if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil { - t.Error(err) - } - } -} diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go deleted file mode 100644 index 764318895..000000000 --- a/pkg/sentry/fs/mounts_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" -) - -// Creates a new MountNamespace with filesystem: -// / (root dir) -// |-foo (dir) -// |-bar (file) -func createMountNamespace(ctx context.Context) (*fs.MountNamespace, error) { - perms := fs.FilePermsFromMode(0777) - m := fs.NewPseudoMountSource() - - barFile := fsutil.NewSimpleFileInode(ctx, fs.RootOwner, perms, 0) - fooDir := ramfs.NewDir(ctx, map[string]*fs.Inode{ - "bar": fs.NewInode(barFile, m, fs.StableAttr{Type: fs.RegularFile}), - }, fs.RootOwner, perms) - rootDir := ramfs.NewDir(ctx, map[string]*fs.Inode{ - "foo": fs.NewInode(fooDir, m, fs.StableAttr{Type: fs.Directory}), - }, fs.RootOwner, perms) - - return fs.NewMountNamespace(ctx, fs.NewInode(rootDir, m, fs.StableAttr{Type: fs.Directory})) -} - -func TestFindLink(t *testing.T) { - ctx := contexttest.Context(t) - mm, err := createMountNamespace(ctx) - if err != nil { - t.Fatalf("createMountNamespace failed: %v", err) - } - - root := mm.Root() - defer root.DecRef() - foo, err := root.Walk(ctx, root, "foo") - if err != nil { - t.Fatalf("Error walking to foo: %v", err) - } - - // Positive cases. - for _, tc := range []struct { - findPath string - wd *fs.Dirent - wantPath string - }{ - {".", root, "/"}, - {".", foo, "/foo"}, - {"..", foo, "/"}, - {"../../..", foo, "/"}, - {"///foo", foo, "/foo"}, - {"/foo", foo, "/foo"}, - {"/foo/bar", foo, "/foo/bar"}, - {"/foo/.///./bar", foo, "/foo/bar"}, - {"/foo///bar", foo, "/foo/bar"}, - {"/foo/../foo/bar", foo, "/foo/bar"}, - {"foo/bar", root, "/foo/bar"}, - {"foo////bar", root, "/foo/bar"}, - {"bar", foo, "/foo/bar"}, - } { - wdPath, _ := tc.wd.FullName(root) - maxTraversals := uint(0) - if d, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err != nil { - t.Errorf("FindLink(%q, wd=%q) failed: %v", tc.findPath, wdPath, err) - } else if got, _ := d.FullName(root); got != tc.wantPath { - t.Errorf("FindLink(%q, wd=%q) got dirent %q, want %q", tc.findPath, wdPath, got, tc.wantPath) - } - } - - // Negative cases. - for _, tc := range []struct { - findPath string - wd *fs.Dirent - }{ - {"bar", root}, - {"/bar", root}, - {"/foo/../../bar", root}, - {"foo", foo}, - } { - wdPath, _ := tc.wd.FullName(root) - maxTraversals := uint(0) - if _, err := mm.FindLink(ctx, root, tc.wd, tc.findPath, &maxTraversals); err == nil { - t.Errorf("FindLink(%q, wd=%q) did not return error", tc.findPath, wdPath) - } - } -} diff --git a/pkg/sentry/fs/path_test.go b/pkg/sentry/fs/path_test.go deleted file mode 100644 index e6f57ebba..000000000 --- a/pkg/sentry/fs/path_test.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fs - -import ( - "testing" -) - -// TestSplitLast tests variants of path splitting. -func TestSplitLast(t *testing.T) { - cases := []struct { - path string - dir string - file string - }{ - {path: "/", dir: "/", file: "."}, - {path: "/.", dir: "/", file: "."}, - {path: "/./", dir: "/", file: "."}, - {path: "/./.", dir: "/.", file: "."}, - {path: "/././", dir: "/.", file: "."}, - {path: "/./..", dir: "/.", file: ".."}, - {path: "/./../", dir: "/.", file: ".."}, - {path: "/..", dir: "/", file: ".."}, - {path: "/../", dir: "/", file: ".."}, - {path: "/../.", dir: "/..", file: "."}, - {path: "/.././", dir: "/..", file: "."}, - {path: "/../..", dir: "/..", file: ".."}, - {path: "/../../", dir: "/..", file: ".."}, - - {path: "", dir: ".", file: "."}, - {path: ".", dir: ".", file: "."}, - {path: "./", dir: ".", file: "."}, - {path: "./.", dir: ".", file: "."}, - {path: "././", dir: ".", file: "."}, - {path: "./..", dir: ".", file: ".."}, - {path: "./../", dir: ".", file: ".."}, - {path: "..", dir: ".", file: ".."}, - {path: "../", dir: ".", file: ".."}, - {path: "../.", dir: "..", file: "."}, - {path: ".././", dir: "..", file: "."}, - {path: "../..", dir: "..", file: ".."}, - {path: "../../", dir: "..", file: ".."}, - - {path: "/foo", dir: "/", file: "foo"}, - {path: "/foo/", dir: "/", file: "foo"}, - {path: "/foo/.", dir: "/foo", file: "."}, - {path: "/foo/./", dir: "/foo", file: "."}, - {path: "/foo/./.", dir: "/foo/.", file: "."}, - {path: "/foo/./..", dir: "/foo/.", file: ".."}, - {path: "/foo/..", dir: "/foo", file: ".."}, - {path: "/foo/../", dir: "/foo", file: ".."}, - {path: "/foo/../.", dir: "/foo/..", file: "."}, - {path: "/foo/../..", dir: "/foo/..", file: ".."}, - - {path: "/foo/bar", dir: "/foo", file: "bar"}, - {path: "/foo/bar/", dir: "/foo", file: "bar"}, - {path: "/foo/bar/.", dir: "/foo/bar", file: "."}, - {path: "/foo/bar/./", dir: "/foo/bar", file: "."}, - {path: "/foo/bar/./.", dir: "/foo/bar/.", file: "."}, - {path: "/foo/bar/./..", dir: "/foo/bar/.", file: ".."}, - {path: "/foo/bar/..", dir: "/foo/bar", file: ".."}, - {path: "/foo/bar/../", dir: "/foo/bar", file: ".."}, - {path: "/foo/bar/../.", dir: "/foo/bar/..", file: "."}, - {path: "/foo/bar/../..", dir: "/foo/bar/..", file: ".."}, - - {path: "foo", dir: ".", file: "foo"}, - {path: "foo", dir: ".", file: "foo"}, - {path: "foo/", dir: ".", file: "foo"}, - {path: "foo/.", dir: "foo", file: "."}, - {path: "foo/./", dir: "foo", file: "."}, - {path: "foo/./.", dir: "foo/.", file: "."}, - {path: "foo/./..", dir: "foo/.", file: ".."}, - {path: "foo/..", dir: "foo", file: ".."}, - {path: "foo/../", dir: "foo", file: ".."}, - {path: "foo/../.", dir: "foo/..", file: "."}, - {path: "foo/../..", dir: "foo/..", file: ".."}, - {path: "foo/", dir: ".", file: "foo"}, - {path: "foo/.", dir: "foo", file: "."}, - - {path: "foo/bar", dir: "foo", file: "bar"}, - {path: "foo/bar/", dir: "foo", file: "bar"}, - {path: "foo/bar/.", dir: "foo/bar", file: "."}, - {path: "foo/bar/./", dir: "foo/bar", file: "."}, - {path: "foo/bar/./.", dir: "foo/bar/.", file: "."}, - {path: "foo/bar/./..", dir: "foo/bar/.", file: ".."}, - {path: "foo/bar/..", dir: "foo/bar", file: ".."}, - {path: "foo/bar/../", dir: "foo/bar", file: ".."}, - {path: "foo/bar/../.", dir: "foo/bar/..", file: "."}, - {path: "foo/bar/../..", dir: "foo/bar/..", file: ".."}, - {path: "foo/bar/", dir: "foo", file: "bar"}, - {path: "foo/bar/.", dir: "foo/bar", file: "."}, - } - - for _, c := range cases { - dir, file := SplitLast(c.path) - if dir != c.dir || file != c.file { - t.Errorf("SplitLast(%q) got (%q, %q), expected (%q, %q)", c.path, dir, file, c.dir, c.file) - } - } -} - -// TestSplitFirst tests variants of path splitting. -func TestSplitFirst(t *testing.T) { - cases := []struct { - path string - first string - remainder string - }{ - {path: "/", first: "/", remainder: ""}, - {path: "/.", first: "/", remainder: "."}, - {path: "///.", first: "/", remainder: "//."}, - {path: "/.///", first: "/", remainder: "."}, - {path: "/./.", first: "/", remainder: "./."}, - {path: "/././", first: "/", remainder: "./."}, - {path: "/./..", first: "/", remainder: "./.."}, - {path: "/./../", first: "/", remainder: "./.."}, - {path: "/..", first: "/", remainder: ".."}, - {path: "/../", first: "/", remainder: ".."}, - {path: "/../.", first: "/", remainder: "../."}, - {path: "/.././", first: "/", remainder: "../."}, - {path: "/../..", first: "/", remainder: "../.."}, - {path: "/../../", first: "/", remainder: "../.."}, - - {path: "", first: ".", remainder: ""}, - {path: ".", first: ".", remainder: ""}, - {path: "./", first: ".", remainder: ""}, - {path: ".///", first: ".", remainder: ""}, - {path: "./.", first: ".", remainder: "."}, - {path: "././", first: ".", remainder: "."}, - {path: "./..", first: ".", remainder: ".."}, - {path: "./../", first: ".", remainder: ".."}, - {path: "..", first: "..", remainder: ""}, - {path: "../", first: "..", remainder: ""}, - {path: "../.", first: "..", remainder: "."}, - {path: ".././", first: "..", remainder: "."}, - {path: "../..", first: "..", remainder: ".."}, - {path: "../../", first: "..", remainder: ".."}, - - {path: "/foo", first: "/", remainder: "foo"}, - {path: "/foo/", first: "/", remainder: "foo"}, - {path: "/foo///", first: "/", remainder: "foo"}, - {path: "/foo/.", first: "/", remainder: "foo/."}, - {path: "/foo/./", first: "/", remainder: "foo/."}, - {path: "/foo/./.", first: "/", remainder: "foo/./."}, - {path: "/foo/./..", first: "/", remainder: "foo/./.."}, - {path: "/foo/..", first: "/", remainder: "foo/.."}, - {path: "/foo/../", first: "/", remainder: "foo/.."}, - {path: "/foo/../.", first: "/", remainder: "foo/../."}, - {path: "/foo/../..", first: "/", remainder: "foo/../.."}, - - {path: "/foo/bar", first: "/", remainder: "foo/bar"}, - {path: "///foo/bar", first: "/", remainder: "//foo/bar"}, - {path: "/foo///bar", first: "/", remainder: "foo///bar"}, - {path: "/foo/bar/.", first: "/", remainder: "foo/bar/."}, - {path: "/foo/bar/./", first: "/", remainder: "foo/bar/."}, - {path: "/foo/bar/./.", first: "/", remainder: "foo/bar/./."}, - {path: "/foo/bar/./..", first: "/", remainder: "foo/bar/./.."}, - {path: "/foo/bar/..", first: "/", remainder: "foo/bar/.."}, - {path: "/foo/bar/../", first: "/", remainder: "foo/bar/.."}, - {path: "/foo/bar/../.", first: "/", remainder: "foo/bar/../."}, - {path: "/foo/bar/../..", first: "/", remainder: "foo/bar/../.."}, - - {path: "foo", first: "foo", remainder: ""}, - {path: "foo", first: "foo", remainder: ""}, - {path: "foo/", first: "foo", remainder: ""}, - {path: "foo///", first: "foo", remainder: ""}, - {path: "foo/.", first: "foo", remainder: "."}, - {path: "foo/./", first: "foo", remainder: "."}, - {path: "foo/./.", first: "foo", remainder: "./."}, - {path: "foo/./..", first: "foo", remainder: "./.."}, - {path: "foo/..", first: "foo", remainder: ".."}, - {path: "foo/../", first: "foo", remainder: ".."}, - {path: "foo/../.", first: "foo", remainder: "../."}, - {path: "foo/../..", first: "foo", remainder: "../.."}, - {path: "foo/", first: "foo", remainder: ""}, - {path: "foo/.", first: "foo", remainder: "."}, - - {path: "foo/bar", first: "foo", remainder: "bar"}, - {path: "foo///bar", first: "foo", remainder: "bar"}, - {path: "foo/bar/", first: "foo", remainder: "bar"}, - {path: "foo/bar/.", first: "foo", remainder: "bar/."}, - {path: "foo/bar/./", first: "foo", remainder: "bar/."}, - {path: "foo/bar/./.", first: "foo", remainder: "bar/./."}, - {path: "foo/bar/./..", first: "foo", remainder: "bar/./.."}, - {path: "foo/bar/..", first: "foo", remainder: "bar/.."}, - {path: "foo/bar/../", first: "foo", remainder: "bar/.."}, - {path: "foo/bar/../.", first: "foo", remainder: "bar/../."}, - {path: "foo/bar/../..", first: "foo", remainder: "bar/../.."}, - {path: "foo/bar/", first: "foo", remainder: "bar"}, - {path: "foo/bar/.", first: "foo", remainder: "bar/."}, - } - - for _, c := range cases { - first, remainder := SplitFirst(c.path) - if first != c.first || remainder != c.remainder { - t.Errorf("SplitFirst(%q) got (%q, %q), expected (%q, %q)", c.path, first, remainder, c.first, c.remainder) - } - } -} - -// TestIsSubpath tests the IsSubpath method. -func TestIsSubpath(t *testing.T) { - tcs := []struct { - // Two absolute paths. - pathA string - pathB string - - // Whether pathA is a subpath of pathB. - wantIsSubpath bool - - // Relative path from pathA to pathB. Only checked if - // wantIsSubpath is true. - wantRelpath string - }{ - { - pathA: "/foo/bar/baz", - pathB: "/foo", - wantIsSubpath: true, - wantRelpath: "bar/baz", - }, - { - pathA: "/foo", - pathB: "/foo/bar/baz", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foobar", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foobar", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/foobar", - wantIsSubpath: false, - }, - { - pathA: "/", - pathB: "/foo", - wantIsSubpath: false, - }, - { - pathA: "/foo", - pathB: "/", - wantIsSubpath: true, - wantRelpath: "foo", - }, - { - pathA: "/foo/bar/../bar", - pathB: "/foo", - wantIsSubpath: true, - wantRelpath: "bar", - }, - { - pathA: "/foo/bar", - pathB: "/foo/../foo", - wantIsSubpath: true, - wantRelpath: "bar", - }, - } - - for _, tc := range tcs { - gotRelpath, gotIsSubpath := IsSubpath(tc.pathA, tc.pathB) - if gotRelpath != tc.wantRelpath || gotIsSubpath != tc.wantIsSubpath { - t.Errorf("IsSubpath(%q, %q) got %q %t, want %q %t", tc.pathA, tc.pathB, gotRelpath, gotIsSubpath, tc.wantRelpath, tc.wantIsSubpath) - } - } -} diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD deleted file mode 100644 index b70c583f3..000000000 --- a/pkg/sentry/fs/proc/BUILD +++ /dev/null @@ -1,73 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "proc", - srcs = [ - "cgroup.go", - "cpuinfo.go", - "exec_args.go", - "fds.go", - "filesystems.go", - "fs.go", - "inode.go", - "loadavg.go", - "meminfo.go", - "mounts.go", - "net.go", - "proc.go", - "rpcinet_proc.go", - "stat.go", - "sys.go", - "sys_net.go", - "sys_net_state.go", - "task.go", - "uid_gid_map.go", - "uptime.go", - "version.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/proc/device", - "//pkg/sentry/fs/proc/seqfile", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/mm", - "//pkg/sentry/socket", - "//pkg/sentry/socket/rpcinet", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "proc_test", - size = "small", - srcs = [ - "net_test.go", - "sys_net_test.go", - ], - embed = [":proc"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/inet", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md deleted file mode 100644 index 5d4ec6c7b..000000000 --- a/pkg/sentry/fs/proc/README.md +++ /dev/null @@ -1,332 +0,0 @@ -This document tracks what is implemented in procfs. Refer to -Documentation/filesystems/proc.txt in the Linux project for information about -procfs generally. - -**NOTE**: This document is not guaranteed to be up to date. If you find an -inconsistency, please file a bug. - -[TOC] - -## Kernel data - -The following files are implemented: - -| File /proc/ | Content | -| :------------------------ | :---------------------------------------------------- | -| [cpuinfo](#cpuinfo) | Info about the CPU | -| [filesystems](#filesystems) | Supported filesystems | -| [loadavg](#loadavg) | Load average of last 1, 5 & 15 minutes | -| [meminfo](#meminfo) | Overall memory info | -| [stat](#stat) | Overall kernel statistics | -| [sys](#sys) | Change parameters within the kernel | -| [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus | -| [version](#version) | Kernel version | - -### cpuinfo - -```bash -$ cat /proc/cpuinfo -processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 45 -model name : unknown -stepping : unknown -cpu MHz : 1234.588 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic popcnt tsc_deadline_timer aes xsave avx xsaveopt -bogomips : 1234.59 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management: - -... -``` - -Notable divergences: - -Field name | Notes -:--------------- | :--------------------------------------- -model name | Always unknown -stepping | Always unknown -fpu | Always yes -fpu_exception | Always yes -wp | Always yes -bogomips | Bogus value (matches cpu MHz) -clflush size | Always 64 -cache_alignment | Always 64 -address sizes | Always 46 bits physical, 48 bits virtual -power management | Always blank - -Otherwise fields are derived from the sentry configuration. - -### filesystems - -```bash -$ cat /proc/filesystems -nodev 9p -nodev devpts -nodev devtmpfs -nodev proc -nodev sysfs -nodev tmpfs -``` - -### loadavg - -```bash -$ cat /proc/loadavg -0.00 0.00 0.00 0/0 0 -``` - -Column | Notes -:------------------------------------ | :---------- -CPU.IO utilization in last 1 minute | Always zero -CPU.IO utilization in last 5 minutes | Always zero -CPU.IO utilization in last 10 minutes | Always zero -Num currently running processes | Always zero -Total num processes | Always zero - -TODO(b/62345059): Populate the columns with accurate statistics. - -### meminfo - -```bash -$ cat /proc/meminfo -MemTotal: 2097152 kB -MemFree: 2083540 kB -MemAvailable: 2083540 kB -Buffers: 0 kB -Cached: 4428 kB -SwapCache: 0 kB -Active: 10812 kB -Inactive: 2216 kB -Active(anon): 8600 kB -Inactive(anon): 0 kB -Active(file): 2212 kB -Inactive(file): 2216 kB -Unevictable: 0 kB -Mlocked: 0 kB -SwapTotal: 0 kB -SwapFree: 0 kB -Dirty: 0 kB -Writeback: 0 kB -AnonPages: 8600 kB -Mapped: 4428 kB -Shmem: 0 kB - -``` - -Notable divergences: - -Field name | Notes -:---------------- | :----------------------------------------------------- -Buffers | Always zero, no block devices -SwapCache | Always zero, no swap -Inactive(anon) | Always zero, see SwapCache -Unevictable | Always zero TODO(b/31823263) -Mlocked | Always zero TODO(b/31823263) -SwapTotal | Always zero, no swap -SwapFree | Always zero, no swap -Dirty | Always zero TODO(b/31823263) -Writeback | Always zero TODO(b/31823263) -MemAvailable | Uses the same value as MemFree since there is no swap. -Slab | Missing -SReclaimable | Missing -SUnreclaim | Missing -KernelStack | Missing -PageTables | Missing -NFS_Unstable | Missing -Bounce | Missing -WritebackTmp | Missing -CommitLimit | Missing -Committed_AS | Missing -VmallocTotal | Missing -VmallocUsed | Missing -VmallocChunk | Missing -HardwareCorrupted | Missing -AnonHugePages | Missing -ShmemHugePages | Missing -ShmemPmdMapped | Missing -HugePages_Total | Missing -HugePages_Free | Missing -HugePages_Rsvd | Missing -HugePages_Surp | Missing -Hugepagesize | Missing -DirectMap4k | Missing -DirectMap2M | Missing -DirectMap1G | Missing - -### stat - -```bash -$ cat /proc/stat -cpu 0 0 0 0 0 0 0 0 0 0 -cpu0 0 0 0 0 0 0 0 0 0 0 -cpu1 0 0 0 0 0 0 0 0 0 0 -cpu2 0 0 0 0 0 0 0 0 0 0 -cpu3 0 0 0 0 0 0 0 0 0 0 -cpu4 0 0 0 0 0 0 0 0 0 0 -cpu5 0 0 0 0 0 0 0 0 0 0 -cpu6 0 0 0 0 0 0 0 0 0 0 -cpu7 0 0 0 0 0 0 0 0 0 0 -intr 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -ctxt 0 -btime 1504040968 -processes 0 -procs_running 0 -procs_blokkcked 0 -softirq 0 0 0 0 0 0 0 0 0 0 0 -``` - -All fields except for `btime` are always zero. - -TODO(b/37226836): Populate with accurate fields. - -### sys - -```bash -$ ls /proc/sys -kernel vm -``` - -Directory | Notes -:-------- | :---------------------------- -abi | Missing -debug | Missing -dev | Missing -fs | Missing -kernel | Contains hostname (only) -net | Missing -user | Missing -vm | Contains mmap_min_addr (only) - -### uptime - -```bash -$ cat /proc/uptime -3204.62 0.00 -``` - -Column | Notes -:------------------------------- | :---------------------------- -Total num seconds system running | Time since procfs was mounted -Number of seconds idle | Always zero - -### version - -```bash -$ cat /proc/version -Linux version 4.4 #1 SMP Sun Jan 10 15:06:54 PST 2016 -``` - -## Process-specific data - -The following files are implemented: - -File /proc/PID | Content -:---------------------- | :--------------------------------------------------- -[auxv](#auxv) | Copy of auxiliary vector for the process -[cmdline](#cmdline) | Command line arguments -[comm](#comm) | Command name associated with the process -[environ](#environ) | Process environment -[exe](#exe) | Symlink to the process's executable -[fd](#fd) | Directory containing links to open file descriptors -[fdinfo](#fdinfo) | Information associated with open file descriptors -[gid_map](#gid_map) | Mappings for group IDs inside the user namespace -[io](#io) | IO statistics -[maps](#maps) | Memory mappings (anon, executables, library files) -[mounts](#mounts) | Mounted filesystems -[mountinfo](#mountinfo) | Information about mounts -[ns](#ns) | Directory containing info about supported namespaces -[stat](#stat) | Process statistics -[statm](#statm) | Process memory statistics -[status](#status) | Process status in human readable format -[task](#task) | Directory containing info about running threads -[uid_map](#uid_map) | Mappings for user IDs inside the user namespace - -### auxv - -TODO - -### cmdline - -TODO - -### comm - -TODO - -### environment - -TODO - -### exe - -TODO - -### fd - -TODO - -### fdinfo - -TODO - -### gid_map - -TODO - -### io - -Only has data for rchar, wchar, syscr, and syscw. - -TODO: add more detail. - -### maps - -TODO - -### mounts - -TODO - -### mountinfo - -TODO - -### ns - -TODO - -### stat - -Only has data for pid, comm, state, ppid, utime, stime, cutime, cstime, -num_threads, and exit_signal. - -TODO: add more detail. - -### statm - -Only has data for vss and rss. - -TODO: add more detail. - -### status - -Contains data for Name, State, Tgid, Pid, Ppid, TracerPid, FDSize, VmSize, -VmRSS, Threads, CapInh, CapPrm, CapEff, CapBnd, Seccomp. - -TODO: add more detail. - -### task - -TODO - -### uid_map - -TODO diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD deleted file mode 100644 index 0394451d4..000000000 --- a/pkg/sentry/fs/proc/device/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "device", - srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/device", - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sentry/device"], -) diff --git a/pkg/sentry/fs/proc/device/device_state_autogen.go b/pkg/sentry/fs/proc/device/device_state_autogen.go new file mode 100755 index 000000000..be407ac45 --- /dev/null +++ b/pkg/sentry/fs/proc/device/device_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package device + diff --git a/pkg/sentry/fs/proc/net_test.go b/pkg/sentry/fs/proc/net_test.go deleted file mode 100644 index f18681405..000000000 --- a/pkg/sentry/fs/proc/net_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/inet" -) - -func newIPv6TestStack() *inet.TestStack { - s := inet.NewTestStack() - s.SupportsIPv6Flag = true - return s -} - -func TestIfinet6NoAddresses(t *testing.T) { - n := &ifinet6{s: newIPv6TestStack()} - if got := n.contents(); got != nil { - t.Errorf("Got n.contents() = %v, want = %v", got, nil) - } -} - -func TestIfinet6(t *testing.T) { - s := newIPv6TestStack() - s.InterfacesMap[1] = inet.Interface{Name: "eth0"} - s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"), - }, - } - s.InterfacesMap[2] = inet.Interface{Name: "eth1"} - s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{ - { - Family: linux.AF_INET6, - PrefixLen: 128, - Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"), - }, - } - want := map[string]struct{}{ - "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {}, - "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {}, - } - - n := &ifinet6{s: s} - contents := n.contents() - if len(contents) != len(want) { - t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want)) - } - got := map[string]struct{}{} - for _, l := range contents { - got[l] = struct{}{} - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("Got n.contents() = %v, want = %v", got, want) - } -} diff --git a/pkg/sentry/fs/proc/proc_state_autogen.go b/pkg/sentry/fs/proc/proc_state_autogen.go new file mode 100755 index 000000000..5cb5a8b82 --- /dev/null +++ b/pkg/sentry/fs/proc/proc_state_autogen.go @@ -0,0 +1,657 @@ +// automatically generated by stateify. + +package proc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *execArgInode) beforeSave() {} +func (x *execArgInode) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("arg", &x.arg) + m.Save("t", &x.t) +} + +func (x *execArgInode) afterLoad() {} +func (x *execArgInode) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("arg", &x.arg) + m.Load("t", &x.t) +} + +func (x *execArgFile) beforeSave() {} +func (x *execArgFile) save(m state.Map) { + x.beforeSave() + m.Save("arg", &x.arg) + m.Save("t", &x.t) +} + +func (x *execArgFile) afterLoad() {} +func (x *execArgFile) load(m state.Map) { + m.Load("arg", &x.arg) + m.Load("t", &x.t) +} + +func (x *fdDir) beforeSave() {} +func (x *fdDir) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) +} + +func (x *fdDir) afterLoad() {} +func (x *fdDir) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) +} + +func (x *fdDirFile) beforeSave() {} +func (x *fdDirFile) save(m state.Map) { + x.beforeSave() + m.Save("isInfoFile", &x.isInfoFile) + m.Save("t", &x.t) +} + +func (x *fdDirFile) afterLoad() {} +func (x *fdDirFile) load(m state.Map) { + m.Load("isInfoFile", &x.isInfoFile) + m.Load("t", &x.t) +} + +func (x *fdInfoDir) beforeSave() {} +func (x *fdInfoDir) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) +} + +func (x *fdInfoDir) afterLoad() {} +func (x *fdInfoDir) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) +} + +func (x *filesystemsData) beforeSave() {} +func (x *filesystemsData) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystemsData) afterLoad() {} +func (x *filesystemsData) load(m state.Map) { +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *taskOwnedInodeOps) beforeSave() {} +func (x *taskOwnedInodeOps) save(m state.Map) { + x.beforeSave() + m.Save("InodeOperations", &x.InodeOperations) + m.Save("t", &x.t) +} + +func (x *taskOwnedInodeOps) afterLoad() {} +func (x *taskOwnedInodeOps) load(m state.Map) { + m.Load("InodeOperations", &x.InodeOperations) + m.Load("t", &x.t) +} + +func (x *staticFileInodeOps) beforeSave() {} +func (x *staticFileInodeOps) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *staticFileInodeOps) afterLoad() {} +func (x *staticFileInodeOps) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *loadavgData) beforeSave() {} +func (x *loadavgData) save(m state.Map) { + x.beforeSave() +} + +func (x *loadavgData) afterLoad() {} +func (x *loadavgData) load(m state.Map) { +} + +func (x *meminfoData) beforeSave() {} +func (x *meminfoData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *meminfoData) afterLoad() {} +func (x *meminfoData) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *mountInfoFile) beforeSave() {} +func (x *mountInfoFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *mountInfoFile) afterLoad() {} +func (x *mountInfoFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *mountsFile) beforeSave() {} +func (x *mountsFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *mountsFile) afterLoad() {} +func (x *mountsFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *ifinet6) beforeSave() {} +func (x *ifinet6) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *ifinet6) afterLoad() {} +func (x *ifinet6) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *netDev) beforeSave() {} +func (x *netDev) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *netDev) afterLoad() {} +func (x *netDev) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *netUnix) beforeSave() {} +func (x *netUnix) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *netUnix) afterLoad() {} +func (x *netUnix) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *proc) beforeSave() {} +func (x *proc) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("k", &x.k) + m.Save("pidns", &x.pidns) + m.Save("cgroupControllers", &x.cgroupControllers) +} + +func (x *proc) afterLoad() {} +func (x *proc) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("k", &x.k) + m.Load("pidns", &x.pidns) + m.Load("cgroupControllers", &x.cgroupControllers) +} + +func (x *self) beforeSave() {} +func (x *self) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("pidns", &x.pidns) +} + +func (x *self) afterLoad() {} +func (x *self) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("pidns", &x.pidns) +} + +func (x *threadSelf) beforeSave() {} +func (x *threadSelf) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("pidns", &x.pidns) +} + +func (x *threadSelf) afterLoad() {} +func (x *threadSelf) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("pidns", &x.pidns) +} + +func (x *rootProcFile) beforeSave() {} +func (x *rootProcFile) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) +} + +func (x *rootProcFile) afterLoad() {} +func (x *rootProcFile) load(m state.Map) { + m.Load("iops", &x.iops) +} + +func (x *statData) beforeSave() {} +func (x *statData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *statData) afterLoad() {} +func (x *statData) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *mmapMinAddrData) beforeSave() {} +func (x *mmapMinAddrData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *mmapMinAddrData) afterLoad() {} +func (x *mmapMinAddrData) load(m state.Map) { + m.Load("k", &x.k) +} + +func (x *overcommitMemory) beforeSave() {} +func (x *overcommitMemory) save(m state.Map) { + x.beforeSave() +} + +func (x *overcommitMemory) afterLoad() {} +func (x *overcommitMemory) load(m state.Map) { +} + +func (x *hostname) beforeSave() {} +func (x *hostname) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) +} + +func (x *hostname) afterLoad() {} +func (x *hostname) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) +} + +func (x *hostnameFile) beforeSave() {} +func (x *hostnameFile) save(m state.Map) { + x.beforeSave() +} + +func (x *hostnameFile) afterLoad() {} +func (x *hostnameFile) load(m state.Map) { +} + +func (x *tcpMemInode) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("dir", &x.dir) + m.Save("s", &x.s) + m.Save("size", &x.size) +} + +func (x *tcpMemInode) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("dir", &x.dir) + m.LoadWait("s", &x.s) + m.Load("size", &x.size) + m.AfterLoad(x.afterLoad) +} + +func (x *tcpMemFile) beforeSave() {} +func (x *tcpMemFile) save(m state.Map) { + x.beforeSave() + m.Save("tcpMemInode", &x.tcpMemInode) +} + +func (x *tcpMemFile) afterLoad() {} +func (x *tcpMemFile) load(m state.Map) { + m.Load("tcpMemInode", &x.tcpMemInode) +} + +func (x *tcpSack) beforeSave() {} +func (x *tcpSack) save(m state.Map) { + x.beforeSave() + m.Save("stack", &x.stack) + m.Save("enabled", &x.enabled) + m.Save("SimpleFileInode", &x.SimpleFileInode) +} + +func (x *tcpSack) load(m state.Map) { + m.LoadWait("stack", &x.stack) + m.Load("enabled", &x.enabled) + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.AfterLoad(x.afterLoad) +} + +func (x *tcpSackFile) beforeSave() {} +func (x *tcpSackFile) save(m state.Map) { + x.beforeSave() + m.Save("tcpSack", &x.tcpSack) + m.Save("stack", &x.stack) +} + +func (x *tcpSackFile) afterLoad() {} +func (x *tcpSackFile) load(m state.Map) { + m.Load("tcpSack", &x.tcpSack) + m.LoadWait("stack", &x.stack) +} + +func (x *taskDir) beforeSave() {} +func (x *taskDir) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) + m.Save("pidns", &x.pidns) +} + +func (x *taskDir) afterLoad() {} +func (x *taskDir) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) + m.Load("pidns", &x.pidns) +} + +func (x *subtasks) beforeSave() {} +func (x *subtasks) save(m state.Map) { + x.beforeSave() + m.Save("Dir", &x.Dir) + m.Save("t", &x.t) + m.Save("p", &x.p) +} + +func (x *subtasks) afterLoad() {} +func (x *subtasks) load(m state.Map) { + m.Load("Dir", &x.Dir) + m.Load("t", &x.t) + m.Load("p", &x.p) +} + +func (x *subtasksFile) beforeSave() {} +func (x *subtasksFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("pidns", &x.pidns) +} + +func (x *subtasksFile) afterLoad() {} +func (x *subtasksFile) load(m state.Map) { + m.Load("t", &x.t) + m.Load("pidns", &x.pidns) +} + +func (x *exe) beforeSave() {} +func (x *exe) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("t", &x.t) +} + +func (x *exe) afterLoad() {} +func (x *exe) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("t", &x.t) +} + +func (x *namespaceSymlink) beforeSave() {} +func (x *namespaceSymlink) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) + m.Save("t", &x.t) +} + +func (x *namespaceSymlink) afterLoad() {} +func (x *namespaceSymlink) load(m state.Map) { + m.Load("Symlink", &x.Symlink) + m.Load("t", &x.t) +} + +func (x *mapsData) beforeSave() {} +func (x *mapsData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *mapsData) afterLoad() {} +func (x *mapsData) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *smapsData) beforeSave() {} +func (x *smapsData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *smapsData) afterLoad() {} +func (x *smapsData) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *taskStatData) beforeSave() {} +func (x *taskStatData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("tgstats", &x.tgstats) + m.Save("pidns", &x.pidns) +} + +func (x *taskStatData) afterLoad() {} +func (x *taskStatData) load(m state.Map) { + m.Load("t", &x.t) + m.Load("tgstats", &x.tgstats) + m.Load("pidns", &x.pidns) +} + +func (x *statmData) beforeSave() {} +func (x *statmData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *statmData) afterLoad() {} +func (x *statmData) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *statusData) beforeSave() {} +func (x *statusData) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("pidns", &x.pidns) +} + +func (x *statusData) afterLoad() {} +func (x *statusData) load(m state.Map) { + m.Load("t", &x.t) + m.Load("pidns", &x.pidns) +} + +func (x *ioData) beforeSave() {} +func (x *ioData) save(m state.Map) { + x.beforeSave() + m.Save("ioUsage", &x.ioUsage) +} + +func (x *ioData) afterLoad() {} +func (x *ioData) load(m state.Map) { + m.Load("ioUsage", &x.ioUsage) +} + +func (x *comm) beforeSave() {} +func (x *comm) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("t", &x.t) +} + +func (x *comm) afterLoad() {} +func (x *comm) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("t", &x.t) +} + +func (x *commFile) beforeSave() {} +func (x *commFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *commFile) afterLoad() {} +func (x *commFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *auxvec) beforeSave() {} +func (x *auxvec) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("t", &x.t) +} + +func (x *auxvec) afterLoad() {} +func (x *auxvec) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("t", &x.t) +} + +func (x *auxvecFile) beforeSave() {} +func (x *auxvecFile) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) +} + +func (x *auxvecFile) afterLoad() {} +func (x *auxvecFile) load(m state.Map) { + m.Load("t", &x.t) +} + +func (x *idMapInodeOperations) beforeSave() {} +func (x *idMapInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("t", &x.t) + m.Save("gids", &x.gids) +} + +func (x *idMapInodeOperations) afterLoad() {} +func (x *idMapInodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("t", &x.t) + m.Load("gids", &x.gids) +} + +func (x *idMapFileOperations) beforeSave() {} +func (x *idMapFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) +} + +func (x *idMapFileOperations) afterLoad() {} +func (x *idMapFileOperations) load(m state.Map) { + m.Load("iops", &x.iops) +} + +func (x *uptime) beforeSave() {} +func (x *uptime) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("startTime", &x.startTime) +} + +func (x *uptime) afterLoad() {} +func (x *uptime) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("startTime", &x.startTime) +} + +func (x *uptimeFile) beforeSave() {} +func (x *uptimeFile) save(m state.Map) { + x.beforeSave() + m.Save("startTime", &x.startTime) +} + +func (x *uptimeFile) afterLoad() {} +func (x *uptimeFile) load(m state.Map) { + m.Load("startTime", &x.startTime) +} + +func (x *versionData) beforeSave() {} +func (x *versionData) save(m state.Map) { + x.beforeSave() + m.Save("k", &x.k) +} + +func (x *versionData) afterLoad() {} +func (x *versionData) load(m state.Map) { + m.Load("k", &x.k) +} + +func init() { + state.Register("proc.execArgInode", (*execArgInode)(nil), state.Fns{Save: (*execArgInode).save, Load: (*execArgInode).load}) + state.Register("proc.execArgFile", (*execArgFile)(nil), state.Fns{Save: (*execArgFile).save, Load: (*execArgFile).load}) + state.Register("proc.fdDir", (*fdDir)(nil), state.Fns{Save: (*fdDir).save, Load: (*fdDir).load}) + state.Register("proc.fdDirFile", (*fdDirFile)(nil), state.Fns{Save: (*fdDirFile).save, Load: (*fdDirFile).load}) + state.Register("proc.fdInfoDir", (*fdInfoDir)(nil), state.Fns{Save: (*fdInfoDir).save, Load: (*fdInfoDir).load}) + state.Register("proc.filesystemsData", (*filesystemsData)(nil), state.Fns{Save: (*filesystemsData).save, Load: (*filesystemsData).load}) + state.Register("proc.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("proc.taskOwnedInodeOps", (*taskOwnedInodeOps)(nil), state.Fns{Save: (*taskOwnedInodeOps).save, Load: (*taskOwnedInodeOps).load}) + state.Register("proc.staticFileInodeOps", (*staticFileInodeOps)(nil), state.Fns{Save: (*staticFileInodeOps).save, Load: (*staticFileInodeOps).load}) + state.Register("proc.loadavgData", (*loadavgData)(nil), state.Fns{Save: (*loadavgData).save, Load: (*loadavgData).load}) + state.Register("proc.meminfoData", (*meminfoData)(nil), state.Fns{Save: (*meminfoData).save, Load: (*meminfoData).load}) + state.Register("proc.mountInfoFile", (*mountInfoFile)(nil), state.Fns{Save: (*mountInfoFile).save, Load: (*mountInfoFile).load}) + state.Register("proc.mountsFile", (*mountsFile)(nil), state.Fns{Save: (*mountsFile).save, Load: (*mountsFile).load}) + state.Register("proc.ifinet6", (*ifinet6)(nil), state.Fns{Save: (*ifinet6).save, Load: (*ifinet6).load}) + state.Register("proc.netDev", (*netDev)(nil), state.Fns{Save: (*netDev).save, Load: (*netDev).load}) + state.Register("proc.netUnix", (*netUnix)(nil), state.Fns{Save: (*netUnix).save, Load: (*netUnix).load}) + state.Register("proc.proc", (*proc)(nil), state.Fns{Save: (*proc).save, Load: (*proc).load}) + state.Register("proc.self", (*self)(nil), state.Fns{Save: (*self).save, Load: (*self).load}) + state.Register("proc.threadSelf", (*threadSelf)(nil), state.Fns{Save: (*threadSelf).save, Load: (*threadSelf).load}) + state.Register("proc.rootProcFile", (*rootProcFile)(nil), state.Fns{Save: (*rootProcFile).save, Load: (*rootProcFile).load}) + state.Register("proc.statData", (*statData)(nil), state.Fns{Save: (*statData).save, Load: (*statData).load}) + state.Register("proc.mmapMinAddrData", (*mmapMinAddrData)(nil), state.Fns{Save: (*mmapMinAddrData).save, Load: (*mmapMinAddrData).load}) + state.Register("proc.overcommitMemory", (*overcommitMemory)(nil), state.Fns{Save: (*overcommitMemory).save, Load: (*overcommitMemory).load}) + state.Register("proc.hostname", (*hostname)(nil), state.Fns{Save: (*hostname).save, Load: (*hostname).load}) + state.Register("proc.hostnameFile", (*hostnameFile)(nil), state.Fns{Save: (*hostnameFile).save, Load: (*hostnameFile).load}) + state.Register("proc.tcpMemInode", (*tcpMemInode)(nil), state.Fns{Save: (*tcpMemInode).save, Load: (*tcpMemInode).load}) + state.Register("proc.tcpMemFile", (*tcpMemFile)(nil), state.Fns{Save: (*tcpMemFile).save, Load: (*tcpMemFile).load}) + state.Register("proc.tcpSack", (*tcpSack)(nil), state.Fns{Save: (*tcpSack).save, Load: (*tcpSack).load}) + state.Register("proc.tcpSackFile", (*tcpSackFile)(nil), state.Fns{Save: (*tcpSackFile).save, Load: (*tcpSackFile).load}) + state.Register("proc.taskDir", (*taskDir)(nil), state.Fns{Save: (*taskDir).save, Load: (*taskDir).load}) + state.Register("proc.subtasks", (*subtasks)(nil), state.Fns{Save: (*subtasks).save, Load: (*subtasks).load}) + state.Register("proc.subtasksFile", (*subtasksFile)(nil), state.Fns{Save: (*subtasksFile).save, Load: (*subtasksFile).load}) + state.Register("proc.exe", (*exe)(nil), state.Fns{Save: (*exe).save, Load: (*exe).load}) + state.Register("proc.namespaceSymlink", (*namespaceSymlink)(nil), state.Fns{Save: (*namespaceSymlink).save, Load: (*namespaceSymlink).load}) + state.Register("proc.mapsData", (*mapsData)(nil), state.Fns{Save: (*mapsData).save, Load: (*mapsData).load}) + state.Register("proc.smapsData", (*smapsData)(nil), state.Fns{Save: (*smapsData).save, Load: (*smapsData).load}) + state.Register("proc.taskStatData", (*taskStatData)(nil), state.Fns{Save: (*taskStatData).save, Load: (*taskStatData).load}) + state.Register("proc.statmData", (*statmData)(nil), state.Fns{Save: (*statmData).save, Load: (*statmData).load}) + state.Register("proc.statusData", (*statusData)(nil), state.Fns{Save: (*statusData).save, Load: (*statusData).load}) + state.Register("proc.ioData", (*ioData)(nil), state.Fns{Save: (*ioData).save, Load: (*ioData).load}) + state.Register("proc.comm", (*comm)(nil), state.Fns{Save: (*comm).save, Load: (*comm).load}) + state.Register("proc.commFile", (*commFile)(nil), state.Fns{Save: (*commFile).save, Load: (*commFile).load}) + state.Register("proc.auxvec", (*auxvec)(nil), state.Fns{Save: (*auxvec).save, Load: (*auxvec).load}) + state.Register("proc.auxvecFile", (*auxvecFile)(nil), state.Fns{Save: (*auxvecFile).save, Load: (*auxvecFile).load}) + state.Register("proc.idMapInodeOperations", (*idMapInodeOperations)(nil), state.Fns{Save: (*idMapInodeOperations).save, Load: (*idMapInodeOperations).load}) + state.Register("proc.idMapFileOperations", (*idMapFileOperations)(nil), state.Fns{Save: (*idMapFileOperations).save, Load: (*idMapFileOperations).load}) + state.Register("proc.uptime", (*uptime)(nil), state.Fns{Save: (*uptime).save, Load: (*uptime).load}) + state.Register("proc.uptimeFile", (*uptimeFile)(nil), state.Fns{Save: (*uptimeFile).save, Load: (*uptimeFile).load}) + state.Register("proc.versionData", (*versionData)(nil), state.Fns{Save: (*versionData).save, Load: (*versionData).load}) +} diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD deleted file mode 100644 index 20c3eefc8..000000000 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "seqfile", - srcs = ["seqfile.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/proc/device", - "//pkg/sentry/kernel/time", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "seqfile_test", - size = "small", - srcs = ["seqfile_test.go"], - embed = [":seqfile"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go new file mode 100755 index 000000000..db9f7ceb9 --- /dev/null +++ b/pkg/sentry/fs/proc/seqfile/seqfile_state_autogen.go @@ -0,0 +1,58 @@ +// automatically generated by stateify. + +package seqfile + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SeqData) beforeSave() {} +func (x *SeqData) save(m state.Map) { + x.beforeSave() + m.Save("Buf", &x.Buf) + m.Save("Handle", &x.Handle) +} + +func (x *SeqData) afterLoad() {} +func (x *SeqData) load(m state.Map) { + m.Load("Buf", &x.Buf) + m.Load("Handle", &x.Handle) +} + +func (x *SeqFile) beforeSave() {} +func (x *SeqFile) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("SeqSource", &x.SeqSource) + m.Save("source", &x.source) + m.Save("generation", &x.generation) + m.Save("lastRead", &x.lastRead) +} + +func (x *SeqFile) afterLoad() {} +func (x *SeqFile) load(m state.Map) { + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("SeqSource", &x.SeqSource) + m.Load("source", &x.source) + m.Load("generation", &x.generation) + m.Load("lastRead", &x.lastRead) +} + +func (x *seqFileOperations) beforeSave() {} +func (x *seqFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("seqFile", &x.seqFile) +} + +func (x *seqFileOperations) afterLoad() {} +func (x *seqFileOperations) load(m state.Map) { + m.Load("seqFile", &x.seqFile) +} + +func init() { + state.Register("seqfile.SeqData", (*SeqData)(nil), state.Fns{Save: (*SeqData).save, Load: (*SeqData).load}) + state.Register("seqfile.SeqFile", (*SeqFile)(nil), state.Fns{Save: (*SeqFile).save, Load: (*SeqFile).load}) + state.Register("seqfile.seqFileOperations", (*seqFileOperations)(nil), state.Fns{Save: (*seqFileOperations).save, Load: (*seqFileOperations).load}) +} diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_test.go b/pkg/sentry/fs/proc/seqfile/seqfile_test.go deleted file mode 100644 index aa795164d..000000000 --- a/pkg/sentry/fs/proc/seqfile/seqfile_test.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqfile - -import ( - "bytes" - "fmt" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -type seqTest struct { - actual []SeqData - update bool -} - -func (s *seqTest) Init() { - var sq []SeqData - // Create some SeqData. - for i := 0; i < 10; i++ { - var b []byte - for j := 0; j < 10; j++ { - b = append(b, byte(i)) - } - sq = append(sq, SeqData{ - Buf: b, - Handle: &testHandle{i: i}, - }) - } - s.actual = sq -} - -// NeedsUpdate reports whether we need to update the data we've previously read. -func (s *seqTest) NeedsUpdate(int64) bool { - return s.update -} - -// ReadSeqFiledata returns a slice of SeqData which contains elements -// greater than the handle. -func (s *seqTest) ReadSeqFileData(ctx context.Context, handle SeqHandle) ([]SeqData, int64) { - if handle == nil { - return s.actual, 0 - } - h := *handle.(*testHandle) - var ret []SeqData - for _, b := range s.actual { - // We want the next one. - h2 := *b.Handle.(*testHandle) - if h2.i > h.i { - ret = append(ret, b) - } - } - return ret, 0 -} - -// Flatten a slice of slices into one slice. -func flatten(buf ...[]byte) []byte { - var flat []byte - for _, b := range buf { - flat = append(flat, b...) - } - return flat -} - -type testHandle struct { - i int -} - -type testTable struct { - offset int64 - readBufferSize int - expectedData []byte - expectedError error -} - -func runTableTests(ctx context.Context, table []testTable, dirent *fs.Dirent) error { - for _, tt := range table { - file, err := dirent.Inode.InodeOperations.GetFile(ctx, dirent, fs.FileFlags{Read: true}) - if err != nil { - return fmt.Errorf("GetFile returned error: %v", err) - } - - data := make([]byte, tt.readBufferSize) - resultLen, err := file.Preadv(ctx, usermem.BytesIOSequence(data), tt.offset) - if err != tt.expectedError { - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (error) => %v expected %v", tt.readBufferSize, tt.offset, err, tt.expectedError) - } - expectedLen := int64(len(tt.expectedData)) - if resultLen != expectedLen { - // We make this just an error so we wall through and print the data below. - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (size) => %v expected %v", tt.readBufferSize, tt.offset, resultLen, expectedLen) - } - if !bytes.Equal(data[:expectedLen], tt.expectedData) { - return fmt.Errorf("t.Preadv(len: %v, offset: %v) (data) => %v expected %v", tt.readBufferSize, tt.offset, data[:expectedLen], tt.expectedData) - } - } - return nil -} - -func TestSeqFile(t *testing.T) { - testSource := &seqTest{} - testSource.Init() - - // Create a file that can be R/W. - m := fs.NewPseudoMountSource() - ctx := contexttest.Context(t) - contents := map[string]*fs.Inode{ - "foo": NewSeqFileInode(ctx, testSource, m), - } - root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777)) - - // How about opening it? - inode := fs.NewInode(root, m, fs.StableAttr{Type: fs.Directory}) - dirent2, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo for n2: %v", err) - } - n2 := dirent2.Inode.InodeOperations - file2, err := n2.GetFile(ctx, dirent2, fs.FileFlags{Read: true, Write: true}) - if err != nil { - t.Fatalf("GetFile returned error: %v", err) - } - - // Writing? - if _, err := file2.Writev(ctx, usermem.BytesIOSequence([]byte("test"))); err == nil { - t.Fatalf("managed to write to n2: %v", err) - } - - // How about reading? - dirent3, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo: %v", err) - } - n3 := dirent3.Inode.InodeOperations - if n2 != n3 { - t.Error("got n2 != n3, want same") - } - - testSource.update = true - - table := []testTable{ - // Read past the end. - {100, 4, []byte{}, io.EOF}, - {110, 4, []byte{}, io.EOF}, - {200, 4, []byte{}, io.EOF}, - // Read a truncated first line. - {0, 4, testSource.actual[0].Buf[:4], nil}, - // Read the whole first line. - {0, 10, testSource.actual[0].Buf, nil}, - // Read the whole first line + 5 bytes of second line. - {0, 15, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:5]), nil}, - // First 4 bytes of the second line. - {10, 4, testSource.actual[1].Buf[:4], nil}, - // Read the two first lines. - {0, 20, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf), nil}, - // Read three lines. - {0, 30, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf), nil}, - // Read everything, but use a bigger buffer than necessary. - {0, 150, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf, testSource.actual[2].Buf, testSource.actual[3].Buf, testSource.actual[4].Buf, testSource.actual[5].Buf, testSource.actual[6].Buf, testSource.actual[7].Buf, testSource.actual[8].Buf, testSource.actual[9].Buf), nil}, - // Read the last 3 bytes. - {97, 10, testSource.actual[9].Buf[7:], nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed with testSource.update = %v : %v", testSource.update, err) - } - - // Disable updates and do it again. - testSource.update = false - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed with testSource.update = %v: %v", testSource.update, err) - } -} - -// Test that we behave correctly when the file is updated. -func TestSeqFileFileUpdated(t *testing.T) { - testSource := &seqTest{} - testSource.Init() - testSource.update = true - - // Create a file that can be R/W. - m := fs.NewPseudoMountSource() - ctx := contexttest.Context(t) - contents := map[string]*fs.Inode{ - "foo": NewSeqFileInode(ctx, testSource, m), - } - root := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0777)) - - // How about opening it? - inode := fs.NewInode(root, m, fs.StableAttr{Type: fs.Directory}) - dirent2, err := root.Lookup(ctx, inode, "foo") - if err != nil { - t.Fatalf("failed to walk to foo for dirent2: %v", err) - } - - table := []testTable{ - {0, 16, flatten(testSource.actual[0].Buf, testSource.actual[1].Buf[:6]), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed: %v", err) - } - // Delete the first entry. - cut := testSource.actual[0].Buf - testSource.actual = testSource.actual[1:] - - table = []testTable{ - // Try reading buffer 0 with an offset. This will not delete the old data. - {1, 5, cut[1:6], nil}, - // Reset our file by reading at offset 0. - {0, 10, testSource.actual[0].Buf, nil}, - {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil}, - // Read the same data a second time. - {16, 14, flatten(testSource.actual[1].Buf[6:], testSource.actual[2].Buf), nil}, - // Read the following two lines. - {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after removing first entry: %v", err) - } - - // Add a new duplicate line in the middle (6666...) - after := testSource.actual[5:] - testSource.actual = testSource.actual[:4] - // Note the list must be sorted. - testSource.actual = append(testSource.actual, after[0]) - testSource.actual = append(testSource.actual, after...) - - table = []testTable{ - {50, 20, flatten(testSource.actual[4].Buf, testSource.actual[5].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after adding middle entry: %v", err) - } - // This will be used in a later test. - oldTestData := testSource.actual - - // Delete everything. - testSource.actual = testSource.actual[:0] - table = []testTable{ - {20, 20, []byte{}, io.EOF}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after removing all entries: %v", err) - } - // Restore some of the data. - testSource.actual = oldTestData[:1] - table = []testTable{ - {6, 20, testSource.actual[0].Buf[6:], nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after adding first entry back: %v", err) - } - - // Re-extend the data - testSource.actual = oldTestData - table = []testTable{ - {30, 20, flatten(testSource.actual[3].Buf, testSource.actual[4].Buf), nil}, - } - if err := runTableTests(ctx, table, dirent2); err != nil { - t.Errorf("runTableTest failed after extending testSource: %v", err) - } -} diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go deleted file mode 100644 index 6abae7a60..000000000 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proc - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -func TestQuerySendBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - s.TCPSendBufSize = inet.TCPBufferSize{100, 200, 300} - tmi := &tcpMemInode{s: s, dir: tcpWMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - buf := make([]byte, 100) - dst := usermem.BytesIOSequence(buf) - n, err := tmf.Read(ctx, nil, dst, 0) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if got, want := string(buf[:n]), "100\t200\t300\n"; got != want { - t.Fatalf("Bad string: got %v, want %v", got, want) - } -} - -func TestQueryRecvBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - s.TCPRecvBufSize = inet.TCPBufferSize{100, 200, 300} - tmi := &tcpMemInode{s: s, dir: tcpRMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - buf := make([]byte, 100) - dst := usermem.BytesIOSequence(buf) - n, err := tmf.Read(ctx, nil, dst, 0) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if got, want := string(buf[:n]), "100\t200\t300\n"; got != want { - t.Fatalf("Bad string: got %v, want %v", got, want) - } -} - -var cases = []struct { - str string - initial inet.TCPBufferSize - final inet.TCPBufferSize -}{ - { - str: "", - initial: inet.TCPBufferSize{1, 2, 3}, - final: inet.TCPBufferSize{1, 2, 3}, - }, - { - str: "100\n", - initial: inet.TCPBufferSize{1, 100, 200}, - final: inet.TCPBufferSize{100, 100, 200}, - }, - { - str: "100 200 300\n", - initial: inet.TCPBufferSize{1, 2, 3}, - final: inet.TCPBufferSize{100, 200, 300}, - }, -} - -func TestConfigureSendBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - for _, c := range cases { - s.TCPSendBufSize = c.initial - tmi := &tcpMemInode{s: s, dir: tcpWMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if s.TCPSendBufSize != c.final { - t.Errorf("TCPSendBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPSendBufSize, c.final) - } - } -} - -func TestConfigureRecvBufferSize(t *testing.T) { - ctx := context.Background() - s := inet.NewTestStack() - for _, c := range cases { - s.TCPRecvBufSize = c.initial - tmi := &tcpMemInode{s: s, dir: tcpRMem} - tmf := &tcpMemFile{tcpMemInode: tmi} - - // Write the values. - src := usermem.BytesIOSequence([]byte(c.str)) - if n, err := tmf.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("Write, case = %q: got (%d, %v), wanted (%d, nil)", c.str, n, err, len(c.str)) - } - - // Read the values from the stack and check them. - if s.TCPRecvBufSize != c.final { - t.Errorf("TCPRecvBufferSize, case = %q: got %v, wanted %v", c.str, s.TCPRecvBufSize, c.final) - } - } -} diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD deleted file mode 100644 index 516efcc4c..000000000 --- a/pkg/sentry/fs/ramfs/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "ramfs", - srcs = [ - "dir.go", - "socket.go", - "symlink.go", - "tree.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ramfs", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "ramfs_test", - size = "small", - srcs = ["tree_test.go"], - embed = [":ramfs"], - deps = [ - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - ], -) diff --git a/pkg/sentry/fs/ramfs/ramfs_state_autogen.go b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go new file mode 100755 index 000000000..bc86a01a1 --- /dev/null +++ b/pkg/sentry/fs/ramfs/ramfs_state_autogen.go @@ -0,0 +1,94 @@ +// automatically generated by stateify. + +package ramfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Dir) beforeSave() {} +func (x *Dir) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("children", &x.children) + m.Save("dentryMap", &x.dentryMap) +} + +func (x *Dir) afterLoad() {} +func (x *Dir) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("children", &x.children) + m.Load("dentryMap", &x.dentryMap) +} + +func (x *dirFileOperations) beforeSave() {} +func (x *dirFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("dirCursor", &x.dirCursor) + m.Save("dir", &x.dir) +} + +func (x *dirFileOperations) afterLoad() {} +func (x *dirFileOperations) load(m state.Map) { + m.Load("dirCursor", &x.dirCursor) + m.Load("dir", &x.dir) +} + +func (x *Socket) beforeSave() {} +func (x *Socket) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("ep", &x.ep) +} + +func (x *Socket) afterLoad() {} +func (x *Socket) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("ep", &x.ep) +} + +func (x *socketFileOperations) beforeSave() {} +func (x *socketFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *socketFileOperations) afterLoad() {} +func (x *socketFileOperations) load(m state.Map) { +} + +func (x *Symlink) beforeSave() {} +func (x *Symlink) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("Target", &x.Target) +} + +func (x *Symlink) afterLoad() {} +func (x *Symlink) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("Target", &x.Target) +} + +func (x *symlinkFileOperations) beforeSave() {} +func (x *symlinkFileOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *symlinkFileOperations) afterLoad() {} +func (x *symlinkFileOperations) load(m state.Map) { +} + +func init() { + state.Register("ramfs.Dir", (*Dir)(nil), state.Fns{Save: (*Dir).save, Load: (*Dir).load}) + state.Register("ramfs.dirFileOperations", (*dirFileOperations)(nil), state.Fns{Save: (*dirFileOperations).save, Load: (*dirFileOperations).load}) + state.Register("ramfs.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load}) + state.Register("ramfs.socketFileOperations", (*socketFileOperations)(nil), state.Fns{Save: (*socketFileOperations).save, Load: (*socketFileOperations).load}) + state.Register("ramfs.Symlink", (*Symlink)(nil), state.Fns{Save: (*Symlink).save, Load: (*Symlink).load}) + state.Register("ramfs.symlinkFileOperations", (*symlinkFileOperations)(nil), state.Fns{Save: (*symlinkFileOperations).save, Load: (*symlinkFileOperations).load}) +} diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go deleted file mode 100644 index da259a971..000000000 --- a/pkg/sentry/fs/ramfs/tree_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ramfs - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" -) - -func TestMakeDirectoryTree(t *testing.T) { - mount := fs.NewPseudoMountSource() - - for _, test := range []struct { - name string - subdirs []string - }{ - { - name: "abs paths", - subdirs: []string{ - "/tmp", - "/tmp/a/b", - "/tmp/a/c/d", - "/tmp/c", - "/proc", - "/dev/a/b", - "/tmp", - }, - }, - { - name: "rel paths", - subdirs: []string{ - "tmp", - "tmp/a/b", - "tmp/a/c/d", - "tmp/c", - "proc", - "dev/a/b", - "tmp", - }, - }, - } { - ctx := contexttest.Context(t) - tree, err := MakeDirectoryTree(ctx, mount, test.subdirs) - if err != nil { - t.Errorf("%s: failed to make ramfs tree, got error %v, want nil", test.name, err) - continue - } - - // Expect to be able to find each of the paths. - mm, err := fs.NewMountNamespace(ctx, tree) - if err != nil { - t.Errorf("%s: failed to create mount manager: %v", test.name, err) - continue - } - root := mm.Root() - defer mm.DecRef() - - for _, p := range test.subdirs { - maxTraversals := uint(0) - if _, err := mm.FindInode(ctx, root, nil, p, &maxTraversals); err != nil { - t.Errorf("%s: failed to find node %s: %v", test.name, p, err) - break - } - } - } -} diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD deleted file mode 100644 index 70fa3af89..000000000 --- a/pkg/sentry/fs/sys/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "sys", - srcs = [ - "device.go", - "devices.go", - "fs.go", - "sys.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/sys", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/kernel", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/fs/sys/sys_state_autogen.go b/pkg/sentry/fs/sys/sys_state_autogen.go new file mode 100755 index 000000000..603057309 --- /dev/null +++ b/pkg/sentry/fs/sys/sys_state_autogen.go @@ -0,0 +1,34 @@ +// automatically generated by stateify. + +package sys + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *cpunum) beforeSave() {} +func (x *cpunum) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *cpunum) afterLoad() {} +func (x *cpunum) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("InodeStaticFileGetter", &x.InodeStaticFileGetter) +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func init() { + state.Register("sys.cpunum", (*cpunum)(nil), state.Fns{Save: (*cpunum).save, Load: (*cpunum).load}) + state.Register("sys.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) +} diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD deleted file mode 100644 index 1d80daeaf..000000000 --- a/pkg/sentry/fs/timerfd/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "timerfd", - srcs = ["timerfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/timerfd", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel/time", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/fs/timerfd/timerfd_state_autogen.go b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go new file mode 100755 index 000000000..e8d98af97 --- /dev/null +++ b/pkg/sentry/fs/timerfd/timerfd_state_autogen.go @@ -0,0 +1,25 @@ +// automatically generated by stateify. + +package timerfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *TimerOperations) beforeSave() {} +func (x *TimerOperations) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.events) { m.Failf("events is %v, expected zero", x.events) } + m.Save("timer", &x.timer) + m.Save("val", &x.val) +} + +func (x *TimerOperations) afterLoad() {} +func (x *TimerOperations) load(m state.Map) { + m.Load("timer", &x.timer) + m.Load("val", &x.val) +} + +func init() { + state.Register("timerfd.TimerOperations", (*TimerOperations)(nil), state.Fns{Save: (*TimerOperations).save, Load: (*TimerOperations).load}) +} diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD deleted file mode 100644 index 8f7eb5757..000000000 --- a/pkg/sentry/fs/tmpfs/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "tmpfs", - srcs = [ - "device.go", - "file_regular.go", - "fs.go", - "inode_file.go", - "tmpfs.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/metric", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/ramfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/safemem", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "tmpfs_test", - size = "small", - srcs = ["file_test.go"], - embed = [":tmpfs"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/contexttest", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go deleted file mode 100644 index 591f20c52..000000000 --- a/pkg/sentry/fs/tmpfs/file_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmpfs - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -func newFileInode(ctx context.Context) *fs.Inode { - m := fs.NewCachingMountSource(&Filesystem{}, fs.MountSourceFlags{}) - iops := NewInMemoryFile(ctx, usage.Tmpfs, fs.WithCurrentTime(ctx, fs.UnstableAttr{})) - return fs.NewInode(iops, m, fs.StableAttr{ - DeviceID: tmpfsDevice.DeviceID(), - InodeID: tmpfsDevice.NextIno(), - BlockSize: usermem.PageSize, - Type: fs.RegularFile, - }) -} - -func newFile(ctx context.Context) *fs.File { - inode := newFileInode(ctx) - f, _ := inode.GetFile(ctx, fs.NewDirent(inode, "stub"), fs.FileFlags{Read: true, Write: true}) - return f -} - -// Allocate once, write twice. -func TestGrow(t *testing.T) { - ctx := contexttest.Context(t) - f := newFile(ctx) - defer f.DecRef() - - abuf := bytes.Repeat([]byte{'a'}, 68) - n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0) - if n != int64(len(abuf)) || err != nil { - t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(abuf)) - } - - bbuf := bytes.Repeat([]byte{'b'}, 856) - n, err = f.Pwritev(ctx, usermem.BytesIOSequence(bbuf), 68) - if n != int64(len(bbuf)) || err != nil { - t.Fatalf("Pwritev got (%d, %v) want (%d, nil)", n, err, len(bbuf)) - } - - rbuf := make([]byte, len(abuf)+len(bbuf)) - n, err = f.Preadv(ctx, usermem.BytesIOSequence(rbuf), 0) - if n != int64(len(rbuf)) || err != nil { - t.Fatalf("Preadv got (%d, %v) want (%d, nil)", n, err, len(rbuf)) - } - - if want := append(abuf, bbuf...); !bytes.Equal(rbuf, want) { - t.Fatalf("Read %v, want %v", rbuf, want) - } -} diff --git a/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go new file mode 100755 index 000000000..7d73d1c77 --- /dev/null +++ b/pkg/sentry/fs/tmpfs/tmpfs_state_autogen.go @@ -0,0 +1,108 @@ +// automatically generated by stateify. + +package tmpfs + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *regularFileOperations) beforeSave() {} +func (x *regularFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("iops", &x.iops) +} + +func (x *regularFileOperations) afterLoad() {} +func (x *regularFileOperations) load(m state.Map) { + m.Load("iops", &x.iops) +} + +func (x *Filesystem) beforeSave() {} +func (x *Filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *Filesystem) afterLoad() {} +func (x *Filesystem) load(m state.Map) { +} + +func (x *fileInodeOperations) beforeSave() {} +func (x *fileInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Save("kernel", &x.kernel) + m.Save("memUsage", &x.memUsage) + m.Save("attr", &x.attr) + m.Save("mappings", &x.mappings) + m.Save("writableMappingPages", &x.writableMappingPages) + m.Save("data", &x.data) + m.Save("seals", &x.seals) +} + +func (x *fileInodeOperations) afterLoad() {} +func (x *fileInodeOperations) load(m state.Map) { + m.Load("InodeSimpleExtendedAttributes", &x.InodeSimpleExtendedAttributes) + m.Load("kernel", &x.kernel) + m.Load("memUsage", &x.memUsage) + m.Load("attr", &x.attr) + m.Load("mappings", &x.mappings) + m.Load("writableMappingPages", &x.writableMappingPages) + m.Load("data", &x.data) + m.Load("seals", &x.seals) +} + +func (x *Dir) beforeSave() {} +func (x *Dir) save(m state.Map) { + x.beforeSave() + m.Save("ramfsDir", &x.ramfsDir) + m.Save("kernel", &x.kernel) +} + +func (x *Dir) load(m state.Map) { + m.Load("ramfsDir", &x.ramfsDir) + m.Load("kernel", &x.kernel) + m.AfterLoad(x.afterLoad) +} + +func (x *Symlink) beforeSave() {} +func (x *Symlink) save(m state.Map) { + x.beforeSave() + m.Save("Symlink", &x.Symlink) +} + +func (x *Symlink) afterLoad() {} +func (x *Symlink) load(m state.Map) { + m.Load("Symlink", &x.Symlink) +} + +func (x *Socket) beforeSave() {} +func (x *Socket) save(m state.Map) { + x.beforeSave() + m.Save("Socket", &x.Socket) +} + +func (x *Socket) afterLoad() {} +func (x *Socket) load(m state.Map) { + m.Load("Socket", &x.Socket) +} + +func (x *Fifo) beforeSave() {} +func (x *Fifo) save(m state.Map) { + x.beforeSave() + m.Save("InodeOperations", &x.InodeOperations) +} + +func (x *Fifo) afterLoad() {} +func (x *Fifo) load(m state.Map) { + m.Load("InodeOperations", &x.InodeOperations) +} + +func init() { + state.Register("tmpfs.regularFileOperations", (*regularFileOperations)(nil), state.Fns{Save: (*regularFileOperations).save, Load: (*regularFileOperations).load}) + state.Register("tmpfs.Filesystem", (*Filesystem)(nil), state.Fns{Save: (*Filesystem).save, Load: (*Filesystem).load}) + state.Register("tmpfs.fileInodeOperations", (*fileInodeOperations)(nil), state.Fns{Save: (*fileInodeOperations).save, Load: (*fileInodeOperations).load}) + state.Register("tmpfs.Dir", (*Dir)(nil), state.Fns{Save: (*Dir).save, Load: (*Dir).load}) + state.Register("tmpfs.Symlink", (*Symlink)(nil), state.Fns{Save: (*Symlink).save, Load: (*Symlink).load}) + state.Register("tmpfs.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load}) + state.Register("tmpfs.Fifo", (*Fifo)(nil), state.Fns{Save: (*Fifo).save, Load: (*Fifo).load}) +} diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD deleted file mode 100644 index 5e9327aec..000000000 --- a/pkg/sentry/fs/tty/BUILD +++ /dev/null @@ -1,46 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "tty", - srcs = [ - "dir.go", - "fs.go", - "line_discipline.go", - "master.go", - "queue.go", - "slave.go", - "terminal.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tty", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/refs", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/safemem", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/unimpl", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "tty_test", - size = "small", - srcs = ["tty_test.go"], - embed = [":tty"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/fs/tty/tty_state_autogen.go b/pkg/sentry/fs/tty/tty_state_autogen.go new file mode 100755 index 000000000..6c9845627 --- /dev/null +++ b/pkg/sentry/fs/tty/tty_state_autogen.go @@ -0,0 +1,202 @@ +// automatically generated by stateify. + +package tty + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *dirInodeOperations) beforeSave() {} +func (x *dirInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("msrc", &x.msrc) + m.Save("master", &x.master) + m.Save("slaves", &x.slaves) + m.Save("dentryMap", &x.dentryMap) + m.Save("next", &x.next) +} + +func (x *dirInodeOperations) afterLoad() {} +func (x *dirInodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("msrc", &x.msrc) + m.Load("master", &x.master) + m.Load("slaves", &x.slaves) + m.Load("dentryMap", &x.dentryMap) + m.Load("next", &x.next) +} + +func (x *dirFileOperations) beforeSave() {} +func (x *dirFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("di", &x.di) + m.Save("dirCursor", &x.dirCursor) +} + +func (x *dirFileOperations) afterLoad() {} +func (x *dirFileOperations) load(m state.Map) { + m.Load("di", &x.di) + m.Load("dirCursor", &x.dirCursor) +} + +func (x *filesystem) beforeSave() {} +func (x *filesystem) save(m state.Map) { + x.beforeSave() +} + +func (x *filesystem) afterLoad() {} +func (x *filesystem) load(m state.Map) { +} + +func (x *superOperations) beforeSave() {} +func (x *superOperations) save(m state.Map) { + x.beforeSave() +} + +func (x *superOperations) afterLoad() {} +func (x *superOperations) load(m state.Map) { +} + +func (x *lineDiscipline) beforeSave() {} +func (x *lineDiscipline) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.masterWaiter) { m.Failf("masterWaiter is %v, expected zero", x.masterWaiter) } + if !state.IsZeroValue(x.slaveWaiter) { m.Failf("slaveWaiter is %v, expected zero", x.slaveWaiter) } + m.Save("size", &x.size) + m.Save("inQueue", &x.inQueue) + m.Save("outQueue", &x.outQueue) + m.Save("termios", &x.termios) + m.Save("column", &x.column) +} + +func (x *lineDiscipline) afterLoad() {} +func (x *lineDiscipline) load(m state.Map) { + m.Load("size", &x.size) + m.Load("inQueue", &x.inQueue) + m.Load("outQueue", &x.outQueue) + m.Load("termios", &x.termios) + m.Load("column", &x.column) +} + +func (x *outputQueueTransformer) beforeSave() {} +func (x *outputQueueTransformer) save(m state.Map) { + x.beforeSave() +} + +func (x *outputQueueTransformer) afterLoad() {} +func (x *outputQueueTransformer) load(m state.Map) { +} + +func (x *inputQueueTransformer) beforeSave() {} +func (x *inputQueueTransformer) save(m state.Map) { + x.beforeSave() +} + +func (x *inputQueueTransformer) afterLoad() {} +func (x *inputQueueTransformer) load(m state.Map) { +} + +func (x *masterInodeOperations) beforeSave() {} +func (x *masterInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("d", &x.d) +} + +func (x *masterInodeOperations) afterLoad() {} +func (x *masterInodeOperations) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("d", &x.d) +} + +func (x *masterFileOperations) beforeSave() {} +func (x *masterFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("d", &x.d) + m.Save("t", &x.t) +} + +func (x *masterFileOperations) afterLoad() {} +func (x *masterFileOperations) load(m state.Map) { + m.Load("d", &x.d) + m.Load("t", &x.t) +} + +func (x *queue) beforeSave() {} +func (x *queue) save(m state.Map) { + x.beforeSave() + m.Save("readBuf", &x.readBuf) + m.Save("waitBuf", &x.waitBuf) + m.Save("waitBufLen", &x.waitBufLen) + m.Save("readable", &x.readable) + m.Save("transformer", &x.transformer) +} + +func (x *queue) afterLoad() {} +func (x *queue) load(m state.Map) { + m.Load("readBuf", &x.readBuf) + m.Load("waitBuf", &x.waitBuf) + m.Load("waitBufLen", &x.waitBufLen) + m.Load("readable", &x.readable) + m.Load("transformer", &x.transformer) +} + +func (x *slaveInodeOperations) beforeSave() {} +func (x *slaveInodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("SimpleFileInode", &x.SimpleFileInode) + m.Save("d", &x.d) + m.Save("t", &x.t) +} + +func (x *slaveInodeOperations) afterLoad() {} +func (x *slaveInodeOperations) load(m state.Map) { + m.Load("SimpleFileInode", &x.SimpleFileInode) + m.Load("d", &x.d) + m.Load("t", &x.t) +} + +func (x *slaveFileOperations) beforeSave() {} +func (x *slaveFileOperations) save(m state.Map) { + x.beforeSave() + m.Save("si", &x.si) +} + +func (x *slaveFileOperations) afterLoad() {} +func (x *slaveFileOperations) load(m state.Map) { + m.Load("si", &x.si) +} + +func (x *Terminal) beforeSave() {} +func (x *Terminal) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("n", &x.n) + m.Save("d", &x.d) + m.Save("ld", &x.ld) +} + +func (x *Terminal) afterLoad() {} +func (x *Terminal) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("n", &x.n) + m.Load("d", &x.d) + m.Load("ld", &x.ld) +} + +func init() { + state.Register("tty.dirInodeOperations", (*dirInodeOperations)(nil), state.Fns{Save: (*dirInodeOperations).save, Load: (*dirInodeOperations).load}) + state.Register("tty.dirFileOperations", (*dirFileOperations)(nil), state.Fns{Save: (*dirFileOperations).save, Load: (*dirFileOperations).load}) + state.Register("tty.filesystem", (*filesystem)(nil), state.Fns{Save: (*filesystem).save, Load: (*filesystem).load}) + state.Register("tty.superOperations", (*superOperations)(nil), state.Fns{Save: (*superOperations).save, Load: (*superOperations).load}) + state.Register("tty.lineDiscipline", (*lineDiscipline)(nil), state.Fns{Save: (*lineDiscipline).save, Load: (*lineDiscipline).load}) + state.Register("tty.outputQueueTransformer", (*outputQueueTransformer)(nil), state.Fns{Save: (*outputQueueTransformer).save, Load: (*outputQueueTransformer).load}) + state.Register("tty.inputQueueTransformer", (*inputQueueTransformer)(nil), state.Fns{Save: (*inputQueueTransformer).save, Load: (*inputQueueTransformer).load}) + state.Register("tty.masterInodeOperations", (*masterInodeOperations)(nil), state.Fns{Save: (*masterInodeOperations).save, Load: (*masterInodeOperations).load}) + state.Register("tty.masterFileOperations", (*masterFileOperations)(nil), state.Fns{Save: (*masterFileOperations).save, Load: (*masterFileOperations).load}) + state.Register("tty.queue", (*queue)(nil), state.Fns{Save: (*queue).save, Load: (*queue).load}) + state.Register("tty.slaveInodeOperations", (*slaveInodeOperations)(nil), state.Fns{Save: (*slaveInodeOperations).save, Load: (*slaveInodeOperations).load}) + state.Register("tty.slaveFileOperations", (*slaveFileOperations)(nil), state.Fns{Save: (*slaveFileOperations).save, Load: (*slaveFileOperations).load}) + state.Register("tty.Terminal", (*Terminal)(nil), state.Fns{Save: (*Terminal).save, Load: (*Terminal).load}) +} diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go deleted file mode 100644 index 59f07ff8e..000000000 --- a/pkg/sentry/fs/tty/tty_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tty - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) - ctx := contexttest.Context(t) - inBytes := []byte("hello, tty\n") - src := usermem.BytesIOSequence(inBytes) - outBytes := make([]byte, 32) - dst := usermem.BytesIOSequence(outBytes) - - // Write to the input queue. - nw, err := ld.inputQueueWrite(ctx, src) - if err != nil { - t.Fatalf("error writing to input queue: %v", err) - } - if nw != int64(len(inBytes)) { - t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes)) - } - - // Read from the input queue. - nr, err := ld.inputQueueRead(ctx, dst) - if err != nil { - t.Fatalf("error reading from input queue: %v", err) - } - if nr != int64(len(inBytes)) { - t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes)) - } - - outStr := string(outBytes[:nr]) - inStr := string(inBytes) - if outStr != inStr { - t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr) - } -} diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD deleted file mode 100644 index f989f2f8b..000000000 --- a/pkg/sentry/hostcpu/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "hostcpu", - srcs = [ - "getcpu_amd64.s", - "hostcpu.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", - visibility = ["//:sandbox"], -) - -go_test( - name = "hostcpu_test", - size = "small", - srcs = ["hostcpu_test.go"], - embed = [":hostcpu"], -) diff --git a/pkg/sentry/hostcpu/hostcpu_state_autogen.go b/pkg/sentry/hostcpu/hostcpu_state_autogen.go new file mode 100755 index 000000000..f04a56ec0 --- /dev/null +++ b/pkg/sentry/hostcpu/hostcpu_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package hostcpu + diff --git a/pkg/sentry/hostcpu/hostcpu_test.go b/pkg/sentry/hostcpu/hostcpu_test.go deleted file mode 100644 index 7d6885c9e..000000000 --- a/pkg/sentry/hostcpu/hostcpu_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package hostcpu - -import ( - "fmt" - "testing" -) - -func TestMaxValueInLinuxBitmap(t *testing.T) { - for _, test := range []struct { - str string - max uint64 - }{ - {"0", 0}, - {"0\n", 0}, - {"0,2", 2}, - {"0-63", 63}, - {"0-3,8-11", 11}, - } { - t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { - max, err := maxValueInLinuxBitmap(test.str) - if err != nil || max != test.max { - t.Errorf("maxValueInLinuxBitmap: got (%d, %v), wanted (%d, nil)", max, err, test.max) - } - }) - } -} - -func TestMaxValueInLinuxBitmapErrors(t *testing.T) { - for _, str := range []string{"", "\n"} { - t.Run(fmt.Sprintf("%q", str), func(t *testing.T) { - max, err := maxValueInLinuxBitmap(str) - if err == nil { - t.Errorf("maxValueInLinuxBitmap: got (%d, nil), wanted (_, error)", max) - } - t.Log(err) - }) - } -} diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD deleted file mode 100644 index 67831d5a1..000000000 --- a/pkg/sentry/hostmm/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hostmm", - srcs = [ - "cgroup.go", - "hostmm.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostmm", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/hostmm/hostmm_state_autogen.go b/pkg/sentry/hostmm/hostmm_state_autogen.go new file mode 100755 index 000000000..730de5101 --- /dev/null +++ b/pkg/sentry/hostmm/hostmm_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package hostmm + diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD deleted file mode 100644 index 184b566d9..000000000 --- a/pkg/sentry/inet/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "inet", - srcs = [ - "context.go", - "inet.go", - "test_stack.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], -) diff --git a/pkg/sentry/inet/inet_state_autogen.go b/pkg/sentry/inet/inet_state_autogen.go new file mode 100755 index 000000000..00f40556c --- /dev/null +++ b/pkg/sentry/inet/inet_state_autogen.go @@ -0,0 +1,26 @@ +// automatically generated by stateify. + +package inet + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *TCPBufferSize) beforeSave() {} +func (x *TCPBufferSize) save(m state.Map) { + x.beforeSave() + m.Save("Min", &x.Min) + m.Save("Default", &x.Default) + m.Save("Max", &x.Max) +} + +func (x *TCPBufferSize) afterLoad() {} +func (x *TCPBufferSize) load(m state.Map) { + m.Load("Min", &x.Min) + m.Load("Default", &x.Default) + m.Load("Max", &x.Max) +} + +func init() { + state.Register("inet.TCPBufferSize", (*TCPBufferSize)(nil), state.Fns{Save: (*TCPBufferSize).save, Load: (*TCPBufferSize).load}) +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD deleted file mode 100644 index c172d399e..000000000 --- a/pkg/sentry/kernel/BUILD +++ /dev/null @@ -1,237 +0,0 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "pending_signals_list", - out = "pending_signals_list.go", - package = "kernel", - prefix = "pendingSignal", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*pendingSignal", - "Linker": "*pendingSignal", - }, -) - -go_template_instance( - name = "process_group_list", - out = "process_group_list.go", - package = "kernel", - prefix = "processGroup", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*ProcessGroup", - "Linker": "*ProcessGroup", - }, -) - -go_template_instance( - name = "seqatomic_taskgoroutineschedinfo", - out = "seqatomic_taskgoroutineschedinfo.go", - package = "kernel", - suffix = "TaskGoroutineSchedInfo", - template = "//third_party/gvsync:generic_seqatomic", - types = { - "Value": "TaskGoroutineSchedInfo", - }, -) - -go_template_instance( - name = "session_list", - out = "session_list.go", - package = "kernel", - prefix = "session", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Session", - "Linker": "*Session", - }, -) - -go_template_instance( - name = "task_list", - out = "task_list.go", - package = "kernel", - prefix = "task", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Task", - "Linker": "*Task", - }, -) - -go_template_instance( - name = "socket_list", - out = "socket_list.go", - package = "kernel", - prefix = "socket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*SocketEntry", - "Linker": "*SocketEntry", - }, -) - -proto_library( - name = "uncaught_signal_proto", - srcs = ["uncaught_signal.proto"], - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_proto"], -) - -go_proto_library( - name = "uncaught_signal_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", - proto = ":uncaught_signal_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - -go_library( - name = "kernel", - srcs = [ - "abstract_socket_namespace.go", - "context.go", - "fd_map.go", - "fs_context.go", - "ipc_namespace.go", - "kernel.go", - "kernel_state.go", - "pending_signals.go", - "pending_signals_list.go", - "pending_signals_state.go", - "posixtimer.go", - "process_group_list.go", - "ptrace.go", - "ptrace_amd64.go", - "ptrace_arm64.go", - "rseq.go", - "seccomp.go", - "seqatomic_taskgoroutineschedinfo.go", - "session_list.go", - "sessions.go", - "signal.go", - "signal_handlers.go", - "socket_list.go", - "syscalls.go", - "syscalls_state.go", - "syslog.go", - "task.go", - "task_acct.go", - "task_block.go", - "task_clone.go", - "task_context.go", - "task_exec.go", - "task_exit.go", - "task_futex.go", - "task_identity.go", - "task_list.go", - "task_log.go", - "task_net.go", - "task_run.go", - "task_sched.go", - "task_signals.go", - "task_start.go", - "task_stop.go", - "task_syscall.go", - "task_usermem.go", - "thread_group.go", - "threads.go", - "timekeeper.go", - "timekeeper_state.go", - "uts_namespace.go", - "vdso.go", - "version.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel", - imports = [ - "gvisor.dev/gvisor/pkg/bpf", - "gvisor.dev/gvisor/pkg/sentry/device", - "gvisor.dev/gvisor/pkg/tcpip", - ], - visibility = ["//:sandbox"], - deps = [ - ":uncaught_signal_go_proto", - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/binary", - "//pkg/bits", - "//pkg/bpf", - "//pkg/cpuid", - "//pkg/eventchannel", - "//pkg/log", - "//pkg/metric", - "//pkg/refs", - "//pkg/secio", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fs/timerfd", - "//pkg/sentry/hostcpu", - "//pkg/sentry/inet", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/futex", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/kernel/semaphore", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/loader", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/safemem", - "//pkg/sentry/socket/netlink/port", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/time", - "//pkg/sentry/unimpl", - "//pkg/sentry/unimpl:unimplemented_syscall_go_proto", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/state", - "//pkg/state/statefile", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/stack", - "//pkg/waiter", - "//third_party/gvsync", - ], -) - -go_test( - name = "kernel_test", - size = "small", - srcs = [ - "fd_map_test.go", - "table_test.go", - "task_test.go", - "timekeeper_test.go", - ], - embed = [":kernel"], - deps = [ - "//pkg/abi", - "//pkg/sentry/arch", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs/filetest", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/limits", - "//pkg/sentry/pgalloc", - "//pkg/sentry/time", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/kernel/README.md b/pkg/sentry/kernel/README.md deleted file mode 100644 index 427311be8..000000000 --- a/pkg/sentry/kernel/README.md +++ /dev/null @@ -1,108 +0,0 @@ -This package contains: - -- A (partial) emulation of the "core Linux kernel", which governs task - execution and scheduling, system call dispatch, and signal handling. See - below for details. - -- The top-level interface for the sentry's Linux kernel emulation in general, - used by the `main` function of all versions of the sentry. This interface - revolves around the `Env` type (defined in `kernel.go`). - -# Background - -In Linux, each schedulable context is referred to interchangeably as a "task" or -"thread". Tasks can be divided into userspace and kernel tasks. In the sentry, -scheduling is managed by the Go runtime, so each schedulable context is a -goroutine; only "userspace" (application) contexts are referred to as tasks, and -represented by Task objects. (From this point forward, "task" refers to the -sentry's notion of a task unless otherwise specified.) - -At a high level, Linux application threads can be thought of as repeating a "run -loop": - -- Some amount of application code is executed in userspace. - -- A trap (explicit syscall invocation, hardware interrupt or exception, etc.) - causes control flow to switch to the kernel. - -- Some amount of kernel code is executed in kernelspace, e.g. to handle the - cause of the trap. - -- The kernel "returns from the trap" into application code. - -Analogously, each task in the sentry is associated with a *task goroutine* that -executes that task's run loop (`Task.run` in `task_run.go`). However, the -sentry's task run loop differs in structure in order to support saving execution -state to, and resuming execution from, checkpoints. - -While in kernelspace, a Linux thread can be descheduled (cease execution) in a -variety of ways: - -- It can yield or be preempted, becoming temporarily descheduled but still - runnable. At present, the sentry delegates scheduling of runnable threads to - the Go runtime. - -- It can exit, becoming permanently descheduled. The sentry's equivalent is - returning from `Task.run`, terminating the task goroutine. - -- It can enter interruptible sleep, a state in which it can be woken by a - caller-defined wakeup or the receipt of a signal. In the sentry, - interruptible sleep (which is ambiguously referred to as *blocking*) is - implemented by making all events that can end blocking (including signal - notifications) communicated via Go channels and using `select` to multiplex - wakeup sources; see `task_block.go`. - -- It can enter uninterruptible sleep, a state in which it can only be woken by - a caller-defined wakeup. Killable sleep is a closely related variant in - which the task can also be woken by SIGKILL. (These definitions also include - Linux's "group-stopped" (`TASK_STOPPED`) and "ptrace-stopped" - (`TASK_TRACED`) states.) - -To maximize compatibility with Linux, sentry checkpointing appears as a spurious -signal-delivery interrupt on all tasks; interrupted system calls return `EINTR` -or are automatically restarted as usual. However, these semantics require that -uninterruptible and killable sleeps do not appear to be interrupted. In other -words, the state of the task, including its progress through the interrupted -operation, must be preserved by checkpointing. For many such sleeps, the wakeup -condition is application-controlled, making it infeasible to wait for the sleep -to end before checkpointing. Instead, we must support checkpointing progress -through sleeping operations. - -# Implementation - -We break the task's control flow graph into *states*, delimited by: - -1. Points where uninterruptible and killable sleeps may occur. For example, - there exists a state boundary between signal dequeueing and signal delivery - because there may be an intervening ptrace signal-delivery-stop. - -2. Points where sleep-induced branches may "rejoin" normal execution. For - example, the syscall exit state exists because it can be reached immediately - following a synchronous syscall, or after a task that is sleeping in - `execve()` or `vfork()` resumes execution. - -3. Points containing large branches. This is strictly for organizational - purposes. For example, the state that processes interrupt-signaled - conditions is kept separate from the main "app" state to reduce the size of - the latter. - -4. `SyscallReinvoke`, which does not correspond to anything in Linux, and - exists solely to serve the autosave feature. - -![dot -Tpng -Goverlap=false -orun_states.png run_states.dot](g3doc/run_states.png "Task control flow graph") - -States before which a stop may occur are represented as implementations of the -`taskRunState` interface named `run(state)`, allowing them to be saved and -restored. States that cannot be immediately preceded by a stop are simply `Task` -methods named `do(state)`. - -Conditions that can require task goroutines to cease execution for unknown -lengths of time are called *stops*. Stops are divided into *internal stops*, -which are stops whose start and end conditions are implemented within the -sentry, and *external stops*, which are stops whose start and end conditions are -not known to the sentry. Hence all uninterruptible and killable sleeps are -internal stops, and the existence of a pending checkpoint operation is an -external stop. Internal stops are reified into instances of the `TaskStop` type, -while external stops are merely counted. The task run loop alternates between -checking for stops and advancing the task's state. This allows checkpointing to -hold tasks in a stopped state while waiting for all tasks in the system to stop. diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD deleted file mode 100644 index 37cb8c8b9..000000000 --- a/pkg/sentry/kernel/auth/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -go_template_instance( - name = "id_map_range", - out = "id_map_range.go", - package = "auth", - prefix = "idMap", - template = "//pkg/segment:generic_range", - types = { - "T": "uint32", - }, -) - -go_template_instance( - name = "id_map_set", - out = "id_map_set.go", - consts = { - "minDegree": "3", - }, - package = "auth", - prefix = "idMap", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint32", - "Range": "idMapRange", - "Value": "uint32", - "Functions": "idMapFunctions", - }, -) - -go_library( - name = "auth", - srcs = [ - "auth.go", - "capability_set.go", - "context.go", - "credentials.go", - "id.go", - "id_map.go", - "id_map_functions.go", - "id_map_range.go", - "id_map_set.go", - "user_namespace.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/auth", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/bits", - "//pkg/log", - "//pkg/sentry/context", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/kernel/auth/auth_state_autogen.go b/pkg/sentry/kernel/auth/auth_state_autogen.go new file mode 100755 index 000000000..4460d37ed --- /dev/null +++ b/pkg/sentry/kernel/auth/auth_state_autogen.go @@ -0,0 +1,151 @@ +// automatically generated by stateify. + +package auth + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Credentials) beforeSave() {} +func (x *Credentials) save(m state.Map) { + x.beforeSave() + m.Save("RealKUID", &x.RealKUID) + m.Save("EffectiveKUID", &x.EffectiveKUID) + m.Save("SavedKUID", &x.SavedKUID) + m.Save("RealKGID", &x.RealKGID) + m.Save("EffectiveKGID", &x.EffectiveKGID) + m.Save("SavedKGID", &x.SavedKGID) + m.Save("ExtraKGIDs", &x.ExtraKGIDs) + m.Save("PermittedCaps", &x.PermittedCaps) + m.Save("InheritableCaps", &x.InheritableCaps) + m.Save("EffectiveCaps", &x.EffectiveCaps) + m.Save("BoundingCaps", &x.BoundingCaps) + m.Save("KeepCaps", &x.KeepCaps) + m.Save("UserNamespace", &x.UserNamespace) +} + +func (x *Credentials) afterLoad() {} +func (x *Credentials) load(m state.Map) { + m.Load("RealKUID", &x.RealKUID) + m.Load("EffectiveKUID", &x.EffectiveKUID) + m.Load("SavedKUID", &x.SavedKUID) + m.Load("RealKGID", &x.RealKGID) + m.Load("EffectiveKGID", &x.EffectiveKGID) + m.Load("SavedKGID", &x.SavedKGID) + m.Load("ExtraKGIDs", &x.ExtraKGIDs) + m.Load("PermittedCaps", &x.PermittedCaps) + m.Load("InheritableCaps", &x.InheritableCaps) + m.Load("EffectiveCaps", &x.EffectiveCaps) + m.Load("BoundingCaps", &x.BoundingCaps) + m.Load("KeepCaps", &x.KeepCaps) + m.Load("UserNamespace", &x.UserNamespace) +} + +func (x *IDMapEntry) beforeSave() {} +func (x *IDMapEntry) save(m state.Map) { + x.beforeSave() + m.Save("FirstID", &x.FirstID) + m.Save("FirstParentID", &x.FirstParentID) + m.Save("Length", &x.Length) +} + +func (x *IDMapEntry) afterLoad() {} +func (x *IDMapEntry) load(m state.Map) { + m.Load("FirstID", &x.FirstID) + m.Load("FirstParentID", &x.FirstParentID) + m.Load("Length", &x.Length) +} + +func (x *idMapRange) beforeSave() {} +func (x *idMapRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *idMapRange) afterLoad() {} +func (x *idMapRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *idMapSet) beforeSave() {} +func (x *idMapSet) save(m state.Map) { + x.beforeSave() + var root *idMapSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *idMapSet) afterLoad() {} +func (x *idMapSet) load(m state.Map) { + m.LoadValue("root", new(*idMapSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*idMapSegmentDataSlices)) }) +} + +func (x *idMapnode) beforeSave() {} +func (x *idMapnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *idMapnode) afterLoad() {} +func (x *idMapnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *idMapSegmentDataSlices) beforeSave() {} +func (x *idMapSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *idMapSegmentDataSlices) afterLoad() {} +func (x *idMapSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *UserNamespace) beforeSave() {} +func (x *UserNamespace) save(m state.Map) { + x.beforeSave() + m.Save("parent", &x.parent) + m.Save("owner", &x.owner) + m.Save("uidMapFromParent", &x.uidMapFromParent) + m.Save("uidMapToParent", &x.uidMapToParent) + m.Save("gidMapFromParent", &x.gidMapFromParent) + m.Save("gidMapToParent", &x.gidMapToParent) +} + +func (x *UserNamespace) afterLoad() {} +func (x *UserNamespace) load(m state.Map) { + m.Load("parent", &x.parent) + m.Load("owner", &x.owner) + m.Load("uidMapFromParent", &x.uidMapFromParent) + m.Load("uidMapToParent", &x.uidMapToParent) + m.Load("gidMapFromParent", &x.gidMapFromParent) + m.Load("gidMapToParent", &x.gidMapToParent) +} + +func init() { + state.Register("auth.Credentials", (*Credentials)(nil), state.Fns{Save: (*Credentials).save, Load: (*Credentials).load}) + state.Register("auth.IDMapEntry", (*IDMapEntry)(nil), state.Fns{Save: (*IDMapEntry).save, Load: (*IDMapEntry).load}) + state.Register("auth.idMapRange", (*idMapRange)(nil), state.Fns{Save: (*idMapRange).save, Load: (*idMapRange).load}) + state.Register("auth.idMapSet", (*idMapSet)(nil), state.Fns{Save: (*idMapSet).save, Load: (*idMapSet).load}) + state.Register("auth.idMapnode", (*idMapnode)(nil), state.Fns{Save: (*idMapnode).save, Load: (*idMapnode).load}) + state.Register("auth.idMapSegmentDataSlices", (*idMapSegmentDataSlices)(nil), state.Fns{Save: (*idMapSegmentDataSlices).save, Load: (*idMapSegmentDataSlices).load}) + state.Register("auth.UserNamespace", (*UserNamespace)(nil), state.Fns{Save: (*UserNamespace).save, Load: (*UserNamespace).load}) +} diff --git a/pkg/sentry/kernel/auth/id_map_range.go b/pkg/sentry/kernel/auth/id_map_range.go new file mode 100755 index 000000000..833fa3518 --- /dev/null +++ b/pkg/sentry/kernel/auth/id_map_range.go @@ -0,0 +1,62 @@ +package auth + +// A Range represents a contiguous range of T. +// +// +stateify savable +type idMapRange struct { + // Start is the inclusive start of the range. + Start uint32 + + // End is the exclusive end of the range. + End uint32 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r idMapRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r idMapRange) Length() uint32 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r idMapRange) Contains(x uint32) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r idMapRange) Overlaps(r2 idMapRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r idMapRange) IsSupersetOf(r2 idMapRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r idMapRange) Intersect(r2 idMapRange) idMapRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r idMapRange) CanSplitAt(x uint32) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/kernel/auth/id_map_set.go b/pkg/sentry/kernel/auth/id_map_set.go new file mode 100755 index 000000000..f72c839c7 --- /dev/null +++ b/pkg/sentry/kernel/auth/id_map_set.go @@ -0,0 +1,1270 @@ +package auth + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + idMapminDegree = 3 + + idMapmaxDegree = 2 * idMapminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type idMapSet struct { + root idMapnode `state:".(*idMapSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *idMapSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *idMapSet) IsEmptyRange(r idMapRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *idMapSet) Span() uint32 { + var sz uint32 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *idMapSet) SpanRange(r idMapRange) uint32 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint32 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *idMapSet) FirstSegment() idMapIterator { + if s.root.nrSegments == 0 { + return idMapIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *idMapSet) LastSegment() idMapIterator { + if s.root.nrSegments == 0 { + return idMapIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *idMapSet) FirstGap() idMapGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return idMapGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *idMapSet) LastGap() idMapGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return idMapGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *idMapSet) Find(key uint32) (idMapIterator, idMapGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return idMapIterator{n, i}, idMapGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return idMapIterator{}, idMapGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *idMapSet) FindSegment(key uint32) idMapIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *idMapSet) LowerBoundSegment(min uint32) idMapIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *idMapSet) UpperBoundSegment(max uint32) idMapIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *idMapSet) FindGap(key uint32) idMapGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *idMapSet) LowerBoundGap(min uint32) idMapGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *idMapSet) UpperBoundGap(max uint32) idMapGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *idMapSet) Add(r idMapRange, val uint32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *idMapSet) AddWithoutMerging(r idMapRange, val uint32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *idMapSet) Insert(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (idMapFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (idMapFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (idMapFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *idMapSet) InsertWithoutMerging(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *idMapSet) InsertWithoutMergingUnchecked(gap idMapGapIterator, r idMapRange, val uint32) idMapIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return idMapIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *idMapSet) Remove(seg idMapIterator) idMapGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + idMapFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(idMapGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *idMapSet) RemoveAll() { + s.root = idMapnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *idMapSet) RemoveRange(r idMapRange) idMapGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *idMapSet) Merge(first, second idMapIterator) idMapIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *idMapSet) MergeUnchecked(first, second idMapIterator) idMapIterator { + if first.End() == second.Start() { + if mval, ok := (idMapFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return idMapIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *idMapSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *idMapSet) MergeRange(r idMapRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *idMapSet) MergeAdjacent(r idMapRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *idMapSet) Split(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *idMapSet) SplitUnchecked(seg idMapIterator, split uint32) (idMapIterator, idMapIterator) { + val1, val2 := (idMapFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), idMapRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *idMapSet) SplitAt(split uint32) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *idMapSet) Isolate(seg idMapIterator, r idMapRange) idMapIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *idMapSet) ApplyContiguous(r idMapRange, fn func(seg idMapIterator)) idMapGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return idMapGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return idMapGapIterator{} + } + } +} + +// +stateify savable +type idMapnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *idMapnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [idMapmaxDegree - 1]idMapRange + values [idMapmaxDegree - 1]uint32 + children [idMapmaxDegree]*idMapnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *idMapnode) firstSegment() idMapIterator { + for n.hasChildren { + n = n.children[0] + } + return idMapIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *idMapnode) lastSegment() idMapIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return idMapIterator{n, n.nrSegments - 1} +} + +func (n *idMapnode) prevSibling() *idMapnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *idMapnode) nextSibling() *idMapnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *idMapnode) rebalanceBeforeInsert(gap idMapGapIterator) idMapGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < idMapmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:idMapminDegree-1], n.keys[:idMapminDegree-1]) + copy(left.values[:idMapminDegree-1], n.values[:idMapminDegree-1]) + copy(right.keys[:idMapminDegree-1], n.keys[idMapminDegree:]) + copy(right.values[:idMapminDegree-1], n.values[idMapminDegree:]) + n.keys[0], n.values[0] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1] + idMapzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:idMapminDegree], n.children[:idMapminDegree]) + copy(right.children[:idMapminDegree], n.children[idMapminDegree:]) + idMapzeroNodeSlice(n.children[2:]) + for i := 0; i < idMapminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < idMapminDegree { + return idMapGapIterator{left, gap.index} + } + return idMapGapIterator{right, gap.index - idMapminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[idMapminDegree-1], n.values[idMapminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &idMapnode{ + nrSegments: idMapminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:idMapminDegree-1], n.keys[idMapminDegree:]) + copy(sibling.values[:idMapminDegree-1], n.values[idMapminDegree:]) + idMapzeroValueSlice(n.values[idMapminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:idMapminDegree], n.children[idMapminDegree:]) + idMapzeroNodeSlice(n.children[idMapminDegree:]) + for i := 0; i < idMapminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = idMapminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < idMapminDegree { + return gap + } + return idMapGapIterator{sibling, gap.index - idMapminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *idMapnode) rebalanceAfterRemove(gap idMapGapIterator) idMapGapIterator { + for { + if n.nrSegments >= idMapminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return idMapGapIterator{n, 0} + } + if gap.node == n { + return idMapGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= idMapminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + idMapFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return idMapGapIterator{n, n.nrSegments} + } + return idMapGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return idMapGapIterator{p, gap.index} + } + if gap.node == right { + return idMapGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *idMapnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = idMapGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + idMapFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type idMapIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *idMapnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg idMapIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg idMapIterator) Range() idMapRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg idMapIterator) Start() uint32 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg idMapIterator) End() uint32 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg idMapIterator) SetRangeUnchecked(r idMapRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetRange(r idMapRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg idMapIterator) SetStartUnchecked(start uint32) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetStart(start uint32) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg idMapIterator) SetEndUnchecked(end uint32) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg idMapIterator) SetEnd(end uint32) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg idMapIterator) Value() uint32 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg idMapIterator) ValuePtr() *uint32 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg idMapIterator) SetValue(val uint32) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg idMapIterator) PrevSegment() idMapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return idMapIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return idMapIterator{} + } + return idMapsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg idMapIterator) NextSegment() idMapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return idMapIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return idMapIterator{} + } + return idMapsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg idMapIterator) PrevGap() idMapGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return idMapGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg idMapIterator) NextGap() idMapGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return idMapGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg idMapIterator) PrevNonEmpty() (idMapIterator, idMapGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return idMapIterator{}, gap + } + return gap.PrevSegment(), idMapGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg idMapIterator) NextNonEmpty() (idMapIterator, idMapGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return idMapIterator{}, gap + } + return gap.NextSegment(), idMapGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type idMapGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *idMapnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap idMapGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap idMapGapIterator) Range() idMapRange { + return idMapRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap idMapGapIterator) Start() uint32 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return idMapFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap idMapGapIterator) End() uint32 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return idMapFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap idMapGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap idMapGapIterator) PrevSegment() idMapIterator { + return idMapsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap idMapGapIterator) NextSegment() idMapIterator { + return idMapsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap idMapGapIterator) PrevGap() idMapGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return idMapGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap idMapGapIterator) NextGap() idMapGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return idMapGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func idMapsegmentBeforePosition(n *idMapnode, i int) idMapIterator { + for i == 0 { + if n.parent == nil { + return idMapIterator{} + } + n, i = n.parent, n.parentIndex + } + return idMapIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func idMapsegmentAfterPosition(n *idMapnode, i int) idMapIterator { + for i == n.nrSegments { + if n.parent == nil { + return idMapIterator{} + } + n, i = n.parent, n.parentIndex + } + return idMapIterator{n, i} +} + +func idMapzeroValueSlice(slice []uint32) { + + for i := range slice { + idMapFunctions{}.ClearValue(&slice[i]) + } +} + +func idMapzeroNodeSlice(slice []*idMapnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *idMapSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *idMapnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *idMapnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type idMapSegmentDataSlices struct { + Start []uint32 + End []uint32 + Values []uint32 +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *idMapSet) ExportSortedSlices() *idMapSegmentDataSlices { + var sds idMapSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *idMapSet) ImportSortedSlices(sds *idMapSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := idMapRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *idMapSet) saveRoot() *idMapSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *idMapSet) loadRoot(sds *idMapSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD deleted file mode 100644 index bec13a3d9..000000000 --- a/pkg/sentry/kernel/contexttest/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "contexttest", - testonly = 1, - srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/kernel", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - ], -) diff --git a/pkg/sentry/kernel/contexttest/contexttest.go b/pkg/sentry/kernel/contexttest/contexttest.go deleted file mode 100644 index 82f9d8922..000000000 --- a/pkg/sentry/kernel/contexttest/contexttest.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package contexttest provides a test context.Context which includes -// a dummy kernel pointing to a valid platform. -package contexttest - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" -) - -// Context returns a Context that may be used in tests. Uses ptrace as the -// platform.Platform, and provides a stub kernel that only serves to point to -// the platform. -func Context(tb testing.TB) context.Context { - ctx := contexttest.Context(tb) - k := &kernel.Kernel{ - Platform: platform.FromContext(ctx), - } - k.SetMemoryFile(pgalloc.MemoryFileFromContext(ctx)) - ctx.(*contexttest.TestContext).RegisterValue(kernel.CtxKernel, k) - return ctx -} diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD deleted file mode 100644 index fb99cfc8f..000000000 --- a/pkg/sentry/kernel/epoll/BUILD +++ /dev/null @@ -1,51 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "epoll_list", - out = "epoll_list.go", - package = "epoll", - prefix = "pollEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*pollEntry", - "Linker": "*pollEntry", - }, -) - -go_library( - name = "epoll", - srcs = [ - "epoll.go", - "epoll_list.go", - "epoll_state.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/epoll", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/refs", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "epoll_test", - size = "small", - srcs = [ - "epoll_test.go", - ], - embed = [":epoll"], - deps = [ - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs/filetest", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/epoll/epoll_list.go b/pkg/sentry/kernel/epoll/epoll_list.go new file mode 100755 index 000000000..94d5c9e57 --- /dev/null +++ b/pkg/sentry/kernel/epoll/epoll_list.go @@ -0,0 +1,173 @@ +package epoll + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type pollEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (pollEntryElementMapper) linkerFor(elem *pollEntry) *pollEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type pollEntryList struct { + head *pollEntry + tail *pollEntry +} + +// Reset resets list l to the empty state. +func (l *pollEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *pollEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *pollEntryList) Front() *pollEntry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *pollEntryList) Back() *pollEntry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *pollEntryList) PushFront(e *pollEntry) { + pollEntryElementMapper{}.linkerFor(e).SetNext(l.head) + pollEntryElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + pollEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *pollEntryList) PushBack(e *pollEntry) { + pollEntryElementMapper{}.linkerFor(e).SetNext(nil) + pollEntryElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + pollEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *pollEntryList) PushBackList(m *pollEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + pollEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + pollEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *pollEntryList) InsertAfter(b, e *pollEntry) { + a := pollEntryElementMapper{}.linkerFor(b).Next() + pollEntryElementMapper{}.linkerFor(e).SetNext(a) + pollEntryElementMapper{}.linkerFor(e).SetPrev(b) + pollEntryElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + pollEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *pollEntryList) InsertBefore(a, e *pollEntry) { + b := pollEntryElementMapper{}.linkerFor(a).Prev() + pollEntryElementMapper{}.linkerFor(e).SetNext(a) + pollEntryElementMapper{}.linkerFor(e).SetPrev(b) + pollEntryElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + pollEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *pollEntryList) Remove(e *pollEntry) { + prev := pollEntryElementMapper{}.linkerFor(e).Prev() + next := pollEntryElementMapper{}.linkerFor(e).Next() + + if prev != nil { + pollEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + pollEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type pollEntryEntry struct { + next *pollEntry + prev *pollEntry +} + +// Next returns the entry that follows e in the list. +func (e *pollEntryEntry) Next() *pollEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *pollEntryEntry) Prev() *pollEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *pollEntryEntry) SetNext(elem *pollEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *pollEntryEntry) SetPrev(elem *pollEntry) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/epoll/epoll_state_autogen.go b/pkg/sentry/kernel/epoll/epoll_state_autogen.go new file mode 100755 index 000000000..c738a9355 --- /dev/null +++ b/pkg/sentry/kernel/epoll/epoll_state_autogen.go @@ -0,0 +1,99 @@ +// automatically generated by stateify. + +package epoll + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FileIdentifier) beforeSave() {} +func (x *FileIdentifier) save(m state.Map) { + x.beforeSave() + m.Save("File", &x.File) + m.Save("Fd", &x.Fd) +} + +func (x *FileIdentifier) afterLoad() {} +func (x *FileIdentifier) load(m state.Map) { + m.LoadWait("File", &x.File) + m.Load("Fd", &x.Fd) +} + +func (x *pollEntry) beforeSave() {} +func (x *pollEntry) save(m state.Map) { + x.beforeSave() + m.Save("pollEntryEntry", &x.pollEntryEntry) + m.Save("id", &x.id) + m.Save("userData", &x.userData) + m.Save("mask", &x.mask) + m.Save("flags", &x.flags) + m.Save("epoll", &x.epoll) +} + +func (x *pollEntry) load(m state.Map) { + m.Load("pollEntryEntry", &x.pollEntryEntry) + m.LoadWait("id", &x.id) + m.Load("userData", &x.userData) + m.Load("mask", &x.mask) + m.Load("flags", &x.flags) + m.Load("epoll", &x.epoll) + m.AfterLoad(x.afterLoad) +} + +func (x *EventPoll) beforeSave() {} +func (x *EventPoll) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.FilePipeSeek) { m.Failf("FilePipeSeek is %v, expected zero", x.FilePipeSeek) } + if !state.IsZeroValue(x.FileNotDirReaddir) { m.Failf("FileNotDirReaddir is %v, expected zero", x.FileNotDirReaddir) } + if !state.IsZeroValue(x.FileNoFsync) { m.Failf("FileNoFsync is %v, expected zero", x.FileNoFsync) } + if !state.IsZeroValue(x.FileNoopFlush) { m.Failf("FileNoopFlush is %v, expected zero", x.FileNoopFlush) } + if !state.IsZeroValue(x.FileNoIoctl) { m.Failf("FileNoIoctl is %v, expected zero", x.FileNoIoctl) } + if !state.IsZeroValue(x.FileNoMMap) { m.Failf("FileNoMMap is %v, expected zero", x.FileNoMMap) } + if !state.IsZeroValue(x.Queue) { m.Failf("Queue is %v, expected zero", x.Queue) } + m.Save("files", &x.files) + m.Save("readyList", &x.readyList) + m.Save("waitingList", &x.waitingList) + m.Save("disabledList", &x.disabledList) +} + +func (x *EventPoll) load(m state.Map) { + m.Load("files", &x.files) + m.Load("readyList", &x.readyList) + m.Load("waitingList", &x.waitingList) + m.Load("disabledList", &x.disabledList) + m.AfterLoad(x.afterLoad) +} + +func (x *pollEntryList) beforeSave() {} +func (x *pollEntryList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *pollEntryList) afterLoad() {} +func (x *pollEntryList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *pollEntryEntry) beforeSave() {} +func (x *pollEntryEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *pollEntryEntry) afterLoad() {} +func (x *pollEntryEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("epoll.FileIdentifier", (*FileIdentifier)(nil), state.Fns{Save: (*FileIdentifier).save, Load: (*FileIdentifier).load}) + state.Register("epoll.pollEntry", (*pollEntry)(nil), state.Fns{Save: (*pollEntry).save, Load: (*pollEntry).load}) + state.Register("epoll.EventPoll", (*EventPoll)(nil), state.Fns{Save: (*EventPoll).save, Load: (*EventPoll).load}) + state.Register("epoll.pollEntryList", (*pollEntryList)(nil), state.Fns{Save: (*pollEntryList).save, Load: (*pollEntryList).load}) + state.Register("epoll.pollEntryEntry", (*pollEntryEntry)(nil), state.Fns{Save: (*pollEntryEntry).save, Load: (*pollEntryEntry).load}) +} diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go deleted file mode 100644 index 4a20d4c82..000000000 --- a/pkg/sentry/kernel/epoll/epoll_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package epoll - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs/filetest" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestFileDestroyed(t *testing.T) { - f := filetest.NewTestFile(t) - id := FileIdentifier{f, 12} - - efile := NewEventPoll(contexttest.Context(t)) - e := efile.FileOperations.(*EventPoll) - if err := e.AddEntry(id, 0, waiter.EventIn, [2]int32{}); err != nil { - t.Fatalf("addEntry failed: %v", err) - } - - // Check that we get an event reported twice in a row. - evt := e.ReadEvents(1) - if len(evt) != 1 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt)) - } - - evt = e.ReadEvents(1) - if len(evt) != 1 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 1, len(evt)) - } - - // Destroy the file. Check that we get no more events. - f.DecRef() - - evt = e.ReadEvents(1) - if len(evt) != 0 { - t.Fatalf("Unexpected number of ready events: want %v, got %v", 0, len(evt)) - } - -} diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD deleted file mode 100644 index 1c5f979d4..000000000 --- a/pkg/sentry/kernel/eventfd/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "eventfd", - srcs = ["eventfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/fdnotifier", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "eventfd_test", - size = "small", - srcs = ["eventfd_test.go"], - embed = [":eventfd"], - deps = [ - "//pkg/sentry/context/contexttest", - "//pkg/sentry/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go new file mode 100755 index 000000000..f6e94b521 --- /dev/null +++ b/pkg/sentry/kernel/eventfd/eventfd_state_autogen.go @@ -0,0 +1,27 @@ +// automatically generated by stateify. + +package eventfd + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *EventOperations) beforeSave() {} +func (x *EventOperations) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.wq) { m.Failf("wq is %v, expected zero", x.wq) } + m.Save("val", &x.val) + m.Save("semMode", &x.semMode) + m.Save("hostfd", &x.hostfd) +} + +func (x *EventOperations) afterLoad() {} +func (x *EventOperations) load(m state.Map) { + m.Load("val", &x.val) + m.Load("semMode", &x.semMode) + m.Load("hostfd", &x.hostfd) +} + +func init() { + state.Register("eventfd.EventOperations", (*EventOperations)(nil), state.Fns{Save: (*EventOperations).save, Load: (*EventOperations).load}) +} diff --git a/pkg/sentry/kernel/eventfd/eventfd_test.go b/pkg/sentry/kernel/eventfd/eventfd_test.go deleted file mode 100644 index 018c7f3ef..000000000 --- a/pkg/sentry/kernel/eventfd/eventfd_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package eventfd - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestEventfd(t *testing.T) { - initVals := []uint64{ - 0, - // Using a non-zero initial value verifies that writing to an - // eventfd signals when the eventfd's counter was already - // non-zero. - 343, - } - - for _, initVal := range initVals { - ctx := contexttest.Context(t) - - // Make a new event that is writable. - event := New(ctx, initVal, false) - - // Register a callback for a write event. - w, ch := waiter.NewChannelEntry(nil) - event.EventRegister(&w, waiter.EventIn) - defer event.EventUnregister(&w) - - data := []byte("00000124") - // Create and submit a write request. - n, err := event.Writev(ctx, usermem.BytesIOSequence(data)) - if err != nil { - t.Fatal(err) - } - if n != 8 { - t.Errorf("eventfd.write wrote %d bytes, not full int64", n) - } - - // Check if the callback fired due to the write event. - select { - case <-ch: - default: - t.Errorf("Didn't get notified of EventIn after write") - } - } -} - -func TestEventfdStat(t *testing.T) { - ctx := contexttest.Context(t) - - // Make a new event that is writable. - event := New(ctx, 0, false) - - // Create and submit an stat request. - uattr, err := event.Dirent.Inode.UnstableAttr(ctx) - if err != nil { - t.Fatalf("eventfd stat request failed: %v", err) - } - if uattr.Size != 0 { - t.Fatal("EventFD size should be 0") - } -} diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD deleted file mode 100644 index 5eddca115..000000000 --- a/pkg/sentry/kernel/fasync/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "fasync", - srcs = ["fasync.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/fasync", - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/fasync/fasync_state_autogen.go b/pkg/sentry/kernel/fasync/fasync_state_autogen.go new file mode 100755 index 000000000..0ffb3f7ad --- /dev/null +++ b/pkg/sentry/kernel/fasync/fasync_state_autogen.go @@ -0,0 +1,30 @@ +// automatically generated by stateify. + +package fasync + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FileAsync) beforeSave() {} +func (x *FileAsync) save(m state.Map) { + x.beforeSave() + m.Save("e", &x.e) + m.Save("requester", &x.requester) + m.Save("recipientPG", &x.recipientPG) + m.Save("recipientTG", &x.recipientTG) + m.Save("recipientT", &x.recipientT) +} + +func (x *FileAsync) afterLoad() {} +func (x *FileAsync) load(m state.Map) { + m.Load("e", &x.e) + m.Load("requester", &x.requester) + m.Load("recipientPG", &x.recipientPG) + m.Load("recipientTG", &x.recipientTG) + m.Load("recipientT", &x.recipientT) +} + +func init() { + state.Register("fasync.FileAsync", (*FileAsync)(nil), state.Fns{Save: (*FileAsync).save, Load: (*FileAsync).load}) +} diff --git a/pkg/sentry/kernel/fd_map_test.go b/pkg/sentry/kernel/fd_map_test.go deleted file mode 100644 index 8571dbe59..000000000 --- a/pkg/sentry/kernel/fd_map_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/fs/filetest" - "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs" - "gvisor.dev/gvisor/pkg/sentry/limits" -) - -const ( - // maxFD is the maximum FD to try to create in the map. - // This number of open files has been seen in the wild. - maxFD = 2 * 1024 -) - -func newTestFDMap() *FDMap { - return &FDMap{ - files: make(map[kdefs.FD]descriptor), - } -} - -// TestFDMapMany allocates maxFD FDs, i.e. maxes out the FDMap, -// until there is no room, then makes sure that NewFDAt works -// and also that if we remove one and add one that works too. -func TestFDMapMany(t *testing.T) { - file := filetest.NewTestFile(t) - limitSet := limits.NewLimitSet() - limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true /* privileged */) - - f := newTestFDMap() - for i := 0; i < maxFD; i++ { - if _, err := f.NewFDFrom(0, file, FDFlags{}, limitSet); err != nil { - t.Fatalf("Allocated %v FDs but wanted to allocate %v", i, maxFD) - } - } - - if _, err := f.NewFDFrom(0, file, FDFlags{}, limitSet); err == nil { - t.Fatalf("f.NewFDFrom(0, r) in full map: got nil, wanted error") - } - - if err := f.NewFDAt(1, file, FDFlags{}, limitSet); err != nil { - t.Fatalf("f.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } -} - -// TestFDMap does a set of simple tests to make sure simple adds, -// removes, GetRefs, and DecRefs work. The ordering is just weird -// enough that a table-driven approach seemed clumsy. -func TestFDMap(t *testing.T) { - file := filetest.NewTestFile(t) - limitSet := limits.NewLimitSet() - limitSet.Set(limits.NumberOfFiles, limits.Limit{1, maxFD}, true /* privileged */) - - f := newTestFDMap() - if _, err := f.NewFDFrom(0, file, FDFlags{}, limitSet); err != nil { - t.Fatalf("Adding an FD to an empty 1-size map: got %v, want nil", err) - } - - if _, err := f.NewFDFrom(0, file, FDFlags{}, limitSet); err == nil { - t.Fatalf("Adding an FD to a filled 1-size map: got nil, wanted an error") - } - - largeLimit := limits.Limit{maxFD, maxFD} - limitSet.Set(limits.NumberOfFiles, largeLimit, true /* privileged */) - - if fd, err := f.NewFDFrom(0, file, FDFlags{}, limitSet); err != nil { - t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) - } else if fd != kdefs.FD(1) { - t.Fatalf("Added an FD to a resized map: got %v, want 1", fd) - } - - if err := f.NewFDAt(1, file, FDFlags{}, limitSet); err != nil { - t.Fatalf("Replacing FD 1 via f.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) - } - - if err := f.NewFDAt(maxFD+1, file, FDFlags{}, limitSet); err == nil { - t.Fatalf("Using an FD that was too large via f.NewFDAt(%v, r, FDFlags{}): got nil, wanted an error", maxFD+1) - } - - if ref := f.GetFile(1); ref == nil { - t.Fatalf("f.GetFile(1): got nil, wanted %v", file) - } - - if ref := f.GetFile(2); ref != nil { - t.Fatalf("f.GetFile(2): got a %v, wanted nil", ref) - } - - ref, ok := f.Remove(1) - if !ok { - t.Fatalf("f.Remove(1) for an existing FD: failed, want success") - } - ref.DecRef() - - if ref, ok := f.Remove(1); ok { - ref.DecRef() - t.Fatalf("r.Remove(1) for a removed FD: got success, want failure") - } - -} - -func TestDescriptorFlags(t *testing.T) { - file := filetest.NewTestFile(t) - f := newTestFDMap() - limitSet := limits.NewLimitSet() - limitSet.Set(limits.NumberOfFiles, limits.Limit{maxFD, maxFD}, true /* privileged */) - - origFlags := FDFlags{CloseOnExec: true} - - if err := f.NewFDAt(2, file, origFlags, limitSet); err != nil { - t.Fatalf("f.NewFDAt(2, r, FDFlags{}): got %v, wanted nil", err) - } - - newFile, newFlags := f.GetDescriptor(2) - if newFile == nil { - t.Fatalf("f.GetFile(2): got a %v, wanted nil", newFile) - } - - if newFlags != origFlags { - t.Fatalf("new File flags %+v don't match original %+v", newFlags, origFlags) - } -} diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD deleted file mode 100644 index a5cf1f627..000000000 --- a/pkg/sentry/kernel/futex/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "atomicptr_bucket", - out = "atomicptr_bucket.go", - package = "futex", - suffix = "Bucket", - template = "//third_party/gvsync:generic_atomicptr", - types = { - "Value": "bucket", - }, -) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "futex", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Waiter", - "Linker": "*Waiter", - }, -) - -go_library( - name = "futex", - srcs = [ - "atomicptr_bucket.go", - "futex.go", - "waiter_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/futex", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/sentry/context", - "//pkg/sentry/memmap", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) - -go_test( - name = "futex_test", - size = "small", - srcs = ["futex_test.go"], - embed = [":futex"], - deps = ["//pkg/sentry/usermem"], -) diff --git a/pkg/sentry/kernel/futex/atomicptr_bucket.go b/pkg/sentry/kernel/futex/atomicptr_bucket.go new file mode 100755 index 000000000..2251a6e72 --- /dev/null +++ b/pkg/sentry/kernel/futex/atomicptr_bucket.go @@ -0,0 +1,27 @@ +package futex + +import ( + "sync/atomic" + "unsafe" +) + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +type AtomicPtrBucket struct { + ptr unsafe.Pointer +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtrBucket) Load() *bucket { + return (*bucket)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtrBucket) Store(x *bucket) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/sentry/kernel/futex/futex_state_autogen.go b/pkg/sentry/kernel/futex/futex_state_autogen.go new file mode 100755 index 000000000..9c0092bb2 --- /dev/null +++ b/pkg/sentry/kernel/futex/futex_state_autogen.go @@ -0,0 +1,62 @@ +// automatically generated by stateify. + +package futex + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *bucket) beforeSave() {} +func (x *bucket) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.waiters) { m.Failf("waiters is %v, expected zero", x.waiters) } +} + +func (x *bucket) afterLoad() {} +func (x *bucket) load(m state.Map) { +} + +func (x *Manager) beforeSave() {} +func (x *Manager) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.privateBuckets) { m.Failf("privateBuckets is %v, expected zero", x.privateBuckets) } + m.Save("sharedBucket", &x.sharedBucket) +} + +func (x *Manager) afterLoad() {} +func (x *Manager) load(m state.Map) { + m.Load("sharedBucket", &x.sharedBucket) +} + +func (x *waiterList) beforeSave() {} +func (x *waiterList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *waiterList) afterLoad() {} +func (x *waiterList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *waiterEntry) beforeSave() {} +func (x *waiterEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *waiterEntry) afterLoad() {} +func (x *waiterEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("futex.bucket", (*bucket)(nil), state.Fns{Save: (*bucket).save, Load: (*bucket).load}) + state.Register("futex.Manager", (*Manager)(nil), state.Fns{Save: (*Manager).save, Load: (*Manager).load}) + state.Register("futex.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load}) + state.Register("futex.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load}) +} diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go deleted file mode 100644 index 65e5d1428..000000000 --- a/pkg/sentry/kernel/futex/futex_test.go +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package futex - -import ( - "math" - "runtime" - "sync" - "sync/atomic" - "syscall" - "testing" - "unsafe" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -// testData implements the Target interface, and allows us to -// treat the address passed for futex operations as an index in -// a byte slice for testing simplicity. -type testData []byte - -const sizeofInt32 = 4 - -func newTestData(size uint) testData { - return make([]byte, size) -} - -func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) { - val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t[addr])), new) - return val, nil -} - -func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) { - if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t[addr])), old, new) { - return old, nil - } - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil -} - -func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) { - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil -} - -func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) { - return Key{ - Kind: KindSharedMappable, - Offset: uint64(addr), - }, nil -} - -func futexKind(private bool) string { - if private { - return "private" - } - return "shared" -} - -func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) *Waiter { - w := NewWaiter() - if err := m.WaitPrepare(w, ta, addr, private, val, bitmask); err != nil { - t.Fatalf("WaitPrepare failed: %v", err) - } - return w -} - -func TestFutexWake(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start waiting for wakeup. - w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w) - - // Perform a wakeup. - if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 { - t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect the waiter to have been woken. - if !w.woken() { - t.Error("waiter not woken") - } - }) - } -} - -func TestFutexWakeBitmask(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start waiting for wakeup. - w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff) - defer m.WaitComplete(w) - - // Perform a wakeup using the wrong bitmask. - if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 { - t.Errorf("Wake with non-matching bitmask: got (%d, %v), wanted (0, nil)", n, err) - } - - // Expect the waiter to still be waiting. - if w.woken() { - t.Error("waiter woken unexpectedly") - } - - // Perform a wakeup using the right bitmask. - if n, err := m.Wake(d, 0, private, 0x00000001, 1); err != nil || n != 1 { - t.Errorf("Wake with matching bitmask: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect that the waiter was woken. - if !w.woken() { - t.Error("waiter not woken") - } - }) - } -} - -func TestFutexWakeTwo(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(sizeofInt32) - - // Start three waiters waiting for wakeup. - var ws [3]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) - } - - // Perform two wakeups. - const wakeups = 2 - if n, err := m.Wake(d, 0, private, ^uint32(0), 2); err != nil || n != wakeups { - t.Errorf("Wake: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly two waiters were woken. - // We don't get guarantees about exactly which two, - // (although we expect them to be w1 and w2). - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -func TestFutexWakeUnrelated(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(2 * sizeofInt32) - - // Start two waiters waiting for wakeup on different addresses. - w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform two wakeups on the second address. - if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 { - t.Errorf("Wake: got (%d, %v), wanted (1, nil)", n, err) - } - - // Expect that only the second waiter was woken. - if w1.woken() { - t.Error("w1 woken unexpectedly") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(2 * sizeofInt32) - - // Perform wakeups with no waiters. - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 0 { - t.Fatalf("WakeOp: got (%d, %v), wanted (0, nil)", n, err) - } - }) - } -} - -func TestWakeOpFirstNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform 10 wakeups on address 0. - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that both waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpSecondNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address sizeofInt32. - w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform 10 wakeups on address sizeofInt32 (contingent on - // d.Op(0), which should succeed). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 0); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that both waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - }) - } -} - -func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address sizeofInt32. - w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Perform 10 wakeups on address sizeofInt32 (contingent on - // d.Op(1), which should fail). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 0, 10, 1); err != nil || n != 0 { - t.Errorf("WakeOp: got (%d, %v), wanted (0, nil)", n, err) - } - - // Expect that neither waiter was woken. - if w1.woken() { - t.Error("w1 woken unexpectedly") - } - if w2.woken() { - t.Error("w2 woken unexpectedly") - } - }) - } -} - -func TestWakeOpAllNonEmpty(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Add two waiters on address sizeofInt32. - w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) - w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) - - // Perform 10 wakeups on address 0 (unconditionally), and 10 - // wakeups on address sizeofInt32 (contingent on d.Op(0), which - // should succeed). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 0); err != nil || n != 4 { - t.Errorf("WakeOp: got (%d, %v), wanted (4, nil)", n, err) - } - - // Expect that all waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - if !w3.woken() { - t.Error("w3 not woken") - } - if !w4.woken() { - t.Error("w4 not woken") - } - }) - } -} - -func TestWakeOpAllNonEmptyFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add two waiters on address 0. - w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w1) - w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(w2) - - // Add two waiters on address sizeofInt32. - w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w3) - w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0)) - defer m.WaitComplete(w4) - - // Perform 10 wakeups on address 0 (unconditionally), and 10 - // wakeups on address sizeofInt32 (contingent on d.Op(1), which - // should fail). - if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 10, 1); err != nil || n != 2 { - t.Errorf("WakeOp: got (%d, %v), wanted (2, nil)", n, err) - } - - // Expect that only the first two waiters were woken. - if !w1.woken() { - t.Error("w1 not woken") - } - if !w2.woken() { - t.Error("w2 not woken") - } - if w3.woken() { - t.Error("w3 woken unexpectedly") - } - if w4.woken() { - t.Error("w4 woken unexpectedly") - } - }) - } -} - -func TestWakeOpSameAddress(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add four waiters on address 0. - var ws [4]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) - } - - // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup - // on address 0 (contingent on d.Op(0), which should succeed). - const wakeups = 2 - if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 0); err != nil || n != wakeups { - t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly two waiters were woken. - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -func TestWakeOpSameAddressFailingOp(t *testing.T) { - for _, private := range []bool{false, true} { - t.Run(futexKind(private), func(t *testing.T) { - m := NewManager() - d := newTestData(8) - - // Add four waiters on address 0. - var ws [4]*Waiter - for i := range ws { - ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0)) - defer m.WaitComplete(ws[i]) - } - - // Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup - // on address 0 (contingent on d.Op(1), which should fail). - const wakeups = 1 - if n, err := m.WakeOp(d, 0, 0, private, 1, 1, 1); err != nil || n != wakeups { - t.Errorf("WakeOp: got (%d, %v), wanted (%d, nil)", n, err, wakeups) - } - - // Expect that exactly one waiter was woken. - awake := 0 - for i := range ws { - if ws[i].woken() { - awake++ - } - } - if awake != wakeups { - t.Errorf("got %d woken waiters, wanted %d", awake, wakeups) - } - }) - } -} - -const ( - testMutexSize = sizeofInt32 - testMutexLocked uint32 = 1 - testMutexUnlocked uint32 = 0 -) - -// testMutex ties together a testData slice, an address, and a -// futex manager in order to implement the sync.Locker interface. -// Beyond being used as a Locker, this is a simple mechanism for -// changing the underlying values for simpler tests. -type testMutex struct { - a usermem.Addr - d testData - m *Manager -} - -func newTestMutex(addr usermem.Addr, d testData, m *Manager) *testMutex { - return &testMutex{a: addr, d: d, m: m} -} - -// Lock acquires the testMutex. -// This may wait for it to be available via the futex manager. -func (t *testMutex) Lock() { - for { - // Attempt to grab the lock. - if atomic.CompareAndSwapUint32( - (*uint32)(unsafe.Pointer(&t.d[t.a])), - testMutexUnlocked, - testMutexLocked) { - // Lock held. - return - } - - // Wait for it to be "not locked". - w := NewWaiter() - err := t.m.WaitPrepare(w, t.d, t.a, true, testMutexLocked, ^uint32(0)) - if err == syscall.EAGAIN { - continue - } - if err != nil { - // Should never happen. - panic("WaitPrepare returned unexpected error: " + err.Error()) - } - <-w.C - t.m.WaitComplete(w) - } -} - -// Unlock releases the testMutex. -// This will notify any waiters via the futex manager. -func (t *testMutex) Unlock() { - // Unlock. - atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d[t.a])), testMutexUnlocked) - - // Notify all waiters. - t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32) -} - -// This function was shamelessly stolen from mutex_test.go. -func HammerMutex(l sync.Locker, loops int, cdone chan bool) { - for i := 0; i < loops; i++ { - l.Lock() - runtime.Gosched() - l.Unlock() - } - cdone <- true -} - -func TestMutexStress(t *testing.T) { - m := NewManager() - d := newTestData(testMutexSize) - tm := newTestMutex(0*testMutexSize, d, m) - c := make(chan bool) - - for i := 0; i < 10; i++ { - go HammerMutex(tm, 1000, c) - } - - for i := 0; i < 10; i++ { - <-c - } -} diff --git a/pkg/sentry/kernel/futex/waiter_list.go b/pkg/sentry/kernel/futex/waiter_list.go new file mode 100755 index 000000000..cca5c4721 --- /dev/null +++ b/pkg/sentry/kernel/futex/waiter_list.go @@ -0,0 +1,173 @@ +package futex + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *Waiter) *Waiter { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *Waiter + tail *Waiter +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *waiterList) Front() *Waiter { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *waiterList) Back() *Waiter { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *waiterList) PushFront(e *Waiter) { + waiterElementMapper{}.linkerFor(e).SetNext(l.head) + waiterElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *waiterList) PushBack(e *Waiter) { + waiterElementMapper{}.linkerFor(e).SetNext(nil) + waiterElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *waiterList) InsertAfter(b, e *Waiter) { + a := waiterElementMapper{}.linkerFor(b).Next() + waiterElementMapper{}.linkerFor(e).SetNext(a) + waiterElementMapper{}.linkerFor(e).SetPrev(b) + waiterElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *waiterList) InsertBefore(a, e *Waiter) { + b := waiterElementMapper{}.linkerFor(a).Prev() + waiterElementMapper{}.linkerFor(e).SetNext(a) + waiterElementMapper{}.linkerFor(e).SetPrev(b) + waiterElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *waiterList) Remove(e *Waiter) { + prev := waiterElementMapper{}.linkerFor(e).Prev() + next := waiterElementMapper{}.linkerFor(e).Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *Waiter + prev *Waiter +} + +// Next returns the entry that follows e in the list. +func (e *waiterEntry) Next() *Waiter { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *waiterEntry) Prev() *Waiter { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *waiterEntry) SetNext(elem *Waiter) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *waiterEntry) SetPrev(elem *Waiter) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/g3doc/run_states.dot b/pkg/sentry/kernel/g3doc/run_states.dot deleted file mode 100644 index 7861fe1f5..000000000 --- a/pkg/sentry/kernel/g3doc/run_states.dot +++ /dev/null @@ -1,99 +0,0 @@ -digraph { - subgraph { - App; - } - subgraph { - Interrupt; - InterruptAfterSignalDeliveryStop; - } - subgraph { - Syscall; - SyscallAfterPtraceEventSeccomp; - SyscallEnter; - SyscallAfterSyscallEnterStop; - SyscallAfterSysemuStop; - SyscallInvoke; - SyscallAfterPtraceEventClone; - SyscallAfterExecStop; - SyscallAfterVforkStop; - SyscallReinvoke; - SyscallExit; - } - subgraph { - Vsyscall; - VsyscallAfterPtraceEventSeccomp; - VsyscallInvoke; - } - subgraph { - Exit; - ExitMain; // leave thread group, release resources, reparent children, kill PID namespace and wait if TGID 1 - ExitNotify; // signal parent/tracer, become waitable - ExitDone; // represented by t.runState == nil - } - - // Task exit - Exit -> ExitMain; - ExitMain -> ExitNotify; - ExitNotify -> ExitDone; - - // Execution of untrusted application code - App -> App; - - // Interrupts (usually signal delivery) - App -> Interrupt; - Interrupt -> Interrupt; // if other interrupt conditions may still apply - Interrupt -> Exit; // if killed - - // Syscalls - App -> Syscall; - Syscall -> SyscallEnter; - SyscallEnter -> SyscallInvoke; - SyscallInvoke -> SyscallExit; - SyscallExit -> App; - - // exit, exit_group - SyscallInvoke -> Exit; - - // execve - SyscallInvoke -> SyscallAfterExecStop; - SyscallAfterExecStop -> SyscallExit; - SyscallAfterExecStop -> App; // fatal signal pending - - // vfork - SyscallInvoke -> SyscallAfterVforkStop; - SyscallAfterVforkStop -> SyscallExit; - - // Vsyscalls - App -> Vsyscall; - Vsyscall -> VsyscallInvoke; - Vsyscall -> App; // fault while reading return address from stack - VsyscallInvoke -> App; - - // ptrace-specific branches - Interrupt -> InterruptAfterSignalDeliveryStop; - InterruptAfterSignalDeliveryStop -> Interrupt; - SyscallEnter -> SyscallAfterSyscallEnterStop; - SyscallAfterSyscallEnterStop -> SyscallInvoke; - SyscallAfterSyscallEnterStop -> SyscallExit; // skipped by tracer - SyscallAfterSyscallEnterStop -> App; // fatal signal pending - SyscallEnter -> SyscallAfterSysemuStop; - SyscallAfterSysemuStop -> SyscallExit; - SyscallAfterSysemuStop -> App; // fatal signal pending - SyscallInvoke -> SyscallAfterPtraceEventClone; - SyscallAfterPtraceEventClone -> SyscallExit; - SyscallAfterPtraceEventClone -> SyscallAfterVforkStop; - - // seccomp - Syscall -> App; // SECCOMP_RET_TRAP, SECCOMP_RET_ERRNO, SECCOMP_RET_KILL, SECCOMP_RET_TRACE without tracer - Syscall -> SyscallAfterPtraceEventSeccomp; // SECCOMP_RET_TRACE - SyscallAfterPtraceEventSeccomp -> SyscallEnter; - SyscallAfterPtraceEventSeccomp -> SyscallExit; // skipped by tracer - SyscallAfterPtraceEventSeccomp -> App; // fatal signal pending - Vsyscall -> VsyscallAfterPtraceEventSeccomp; - VsyscallAfterPtraceEventSeccomp -> VsyscallInvoke; - VsyscallAfterPtraceEventSeccomp -> App; - - // Autosave - SyscallInvoke -> SyscallReinvoke; - SyscallReinvoke -> SyscallInvoke; -} diff --git a/pkg/sentry/kernel/g3doc/run_states.png b/pkg/sentry/kernel/g3doc/run_states.png Binary files differdeleted file mode 100644 index b63b60f02..000000000 --- a/pkg/sentry/kernel/g3doc/run_states.png +++ /dev/null diff --git a/pkg/sentry/kernel/kdefs/BUILD b/pkg/sentry/kernel/kdefs/BUILD deleted file mode 100644 index 5d62f406a..000000000 --- a/pkg/sentry/kernel/kdefs/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "kdefs", - srcs = ["kdefs.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs", - visibility = ["//:sandbox"], -) diff --git a/pkg/sentry/kernel/kdefs/kdefs_state_autogen.go b/pkg/sentry/kernel/kdefs/kdefs_state_autogen.go new file mode 100755 index 000000000..cef77125b --- /dev/null +++ b/pkg/sentry/kernel/kdefs/kdefs_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package kdefs + diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go new file mode 100755 index 000000000..74b740ec8 --- /dev/null +++ b/pkg/sentry/kernel/kernel_state_autogen.go @@ -0,0 +1,1179 @@ +// automatically generated by stateify. + +package kernel + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/sentry/device" + "gvisor.dev/gvisor/pkg/tcpip" +) + +func (x *abstractEndpoint) beforeSave() {} +func (x *abstractEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("ep", &x.ep) + m.Save("wr", &x.wr) + m.Save("name", &x.name) + m.Save("ns", &x.ns) +} + +func (x *abstractEndpoint) afterLoad() {} +func (x *abstractEndpoint) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("wr", &x.wr) + m.Load("name", &x.name) + m.Load("ns", &x.ns) +} + +func (x *AbstractSocketNamespace) beforeSave() {} +func (x *AbstractSocketNamespace) save(m state.Map) { + x.beforeSave() + m.Save("endpoints", &x.endpoints) +} + +func (x *AbstractSocketNamespace) afterLoad() {} +func (x *AbstractSocketNamespace) load(m state.Map) { + m.Load("endpoints", &x.endpoints) +} + +func (x *FDFlags) beforeSave() {} +func (x *FDFlags) save(m state.Map) { + x.beforeSave() + m.Save("CloseOnExec", &x.CloseOnExec) +} + +func (x *FDFlags) afterLoad() {} +func (x *FDFlags) load(m state.Map) { + m.Load("CloseOnExec", &x.CloseOnExec) +} + +func (x *descriptor) beforeSave() {} +func (x *descriptor) save(m state.Map) { + x.beforeSave() + m.Save("file", &x.file) + m.Save("flags", &x.flags) +} + +func (x *descriptor) afterLoad() {} +func (x *descriptor) load(m state.Map) { + m.Load("file", &x.file) + m.Load("flags", &x.flags) +} + +func (x *FDMap) beforeSave() {} +func (x *FDMap) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("k", &x.k) + m.Save("files", &x.files) + m.Save("uid", &x.uid) +} + +func (x *FDMap) afterLoad() {} +func (x *FDMap) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("k", &x.k) + m.Load("files", &x.files) + m.Load("uid", &x.uid) +} + +func (x *FSContext) beforeSave() {} +func (x *FSContext) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("root", &x.root) + m.Save("cwd", &x.cwd) + m.Save("umask", &x.umask) +} + +func (x *FSContext) afterLoad() {} +func (x *FSContext) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("root", &x.root) + m.Load("cwd", &x.cwd) + m.Load("umask", &x.umask) +} + +func (x *IPCNamespace) beforeSave() {} +func (x *IPCNamespace) save(m state.Map) { + x.beforeSave() + m.Save("userNS", &x.userNS) + m.Save("semaphores", &x.semaphores) + m.Save("shms", &x.shms) +} + +func (x *IPCNamespace) afterLoad() {} +func (x *IPCNamespace) load(m state.Map) { + m.Load("userNS", &x.userNS) + m.Load("semaphores", &x.semaphores) + m.Load("shms", &x.shms) +} + +func (x *Kernel) beforeSave() {} +func (x *Kernel) save(m state.Map) { + x.beforeSave() + var danglingEndpoints []tcpip.Endpoint = x.saveDanglingEndpoints() + m.SaveValue("danglingEndpoints", danglingEndpoints) + var deviceRegistry *device.Registry = x.saveDeviceRegistry() + m.SaveValue("deviceRegistry", deviceRegistry) + m.Save("featureSet", &x.featureSet) + m.Save("timekeeper", &x.timekeeper) + m.Save("tasks", &x.tasks) + m.Save("rootUserNamespace", &x.rootUserNamespace) + m.Save("applicationCores", &x.applicationCores) + m.Save("useHostCores", &x.useHostCores) + m.Save("extraAuxv", &x.extraAuxv) + m.Save("vdso", &x.vdso) + m.Save("rootUTSNamespace", &x.rootUTSNamespace) + m.Save("rootIPCNamespace", &x.rootIPCNamespace) + m.Save("rootAbstractSocketNamespace", &x.rootAbstractSocketNamespace) + m.Save("mounts", &x.mounts) + m.Save("futexes", &x.futexes) + m.Save("globalInit", &x.globalInit) + m.Save("realtimeClock", &x.realtimeClock) + m.Save("monotonicClock", &x.monotonicClock) + m.Save("syslog", &x.syslog) + m.Save("cpuClock", &x.cpuClock) + m.Save("fdMapUids", &x.fdMapUids) + m.Save("uniqueID", &x.uniqueID) + m.Save("nextInotifyCookie", &x.nextInotifyCookie) + m.Save("netlinkPorts", &x.netlinkPorts) + m.Save("sockets", &x.sockets) + m.Save("nextSocketEntry", &x.nextSocketEntry) + m.Save("DirentCacheLimiter", &x.DirentCacheLimiter) +} + +func (x *Kernel) afterLoad() {} +func (x *Kernel) load(m state.Map) { + m.Load("featureSet", &x.featureSet) + m.Load("timekeeper", &x.timekeeper) + m.Load("tasks", &x.tasks) + m.Load("rootUserNamespace", &x.rootUserNamespace) + m.Load("applicationCores", &x.applicationCores) + m.Load("useHostCores", &x.useHostCores) + m.Load("extraAuxv", &x.extraAuxv) + m.Load("vdso", &x.vdso) + m.Load("rootUTSNamespace", &x.rootUTSNamespace) + m.Load("rootIPCNamespace", &x.rootIPCNamespace) + m.Load("rootAbstractSocketNamespace", &x.rootAbstractSocketNamespace) + m.Load("mounts", &x.mounts) + m.Load("futexes", &x.futexes) + m.Load("globalInit", &x.globalInit) + m.Load("realtimeClock", &x.realtimeClock) + m.Load("monotonicClock", &x.monotonicClock) + m.Load("syslog", &x.syslog) + m.Load("cpuClock", &x.cpuClock) + m.Load("fdMapUids", &x.fdMapUids) + m.Load("uniqueID", &x.uniqueID) + m.Load("nextInotifyCookie", &x.nextInotifyCookie) + m.Load("netlinkPorts", &x.netlinkPorts) + m.Load("sockets", &x.sockets) + m.Load("nextSocketEntry", &x.nextSocketEntry) + m.Load("DirentCacheLimiter", &x.DirentCacheLimiter) + m.LoadValue("danglingEndpoints", new([]tcpip.Endpoint), func(y interface{}) { x.loadDanglingEndpoints(y.([]tcpip.Endpoint)) }) + m.LoadValue("deviceRegistry", new(*device.Registry), func(y interface{}) { x.loadDeviceRegistry(y.(*device.Registry)) }) +} + +func (x *SocketEntry) beforeSave() {} +func (x *SocketEntry) save(m state.Map) { + x.beforeSave() + m.Save("socketEntry", &x.socketEntry) + m.Save("k", &x.k) + m.Save("Sock", &x.Sock) + m.Save("ID", &x.ID) +} + +func (x *SocketEntry) afterLoad() {} +func (x *SocketEntry) load(m state.Map) { + m.Load("socketEntry", &x.socketEntry) + m.Load("k", &x.k) + m.Load("Sock", &x.Sock) + m.Load("ID", &x.ID) +} + +func (x *pendingSignals) beforeSave() {} +func (x *pendingSignals) save(m state.Map) { + x.beforeSave() + var signals []savedPendingSignal = x.saveSignals() + m.SaveValue("signals", signals) +} + +func (x *pendingSignals) afterLoad() {} +func (x *pendingSignals) load(m state.Map) { + m.LoadValue("signals", new([]savedPendingSignal), func(y interface{}) { x.loadSignals(y.([]savedPendingSignal)) }) +} + +func (x *pendingSignalQueue) beforeSave() {} +func (x *pendingSignalQueue) save(m state.Map) { + x.beforeSave() + m.Save("pendingSignalList", &x.pendingSignalList) + m.Save("length", &x.length) +} + +func (x *pendingSignalQueue) afterLoad() {} +func (x *pendingSignalQueue) load(m state.Map) { + m.Load("pendingSignalList", &x.pendingSignalList) + m.Load("length", &x.length) +} + +func (x *pendingSignal) beforeSave() {} +func (x *pendingSignal) save(m state.Map) { + x.beforeSave() + m.Save("pendingSignalEntry", &x.pendingSignalEntry) + m.Save("SignalInfo", &x.SignalInfo) + m.Save("timer", &x.timer) +} + +func (x *pendingSignal) afterLoad() {} +func (x *pendingSignal) load(m state.Map) { + m.Load("pendingSignalEntry", &x.pendingSignalEntry) + m.Load("SignalInfo", &x.SignalInfo) + m.Load("timer", &x.timer) +} + +func (x *pendingSignalList) beforeSave() {} +func (x *pendingSignalList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *pendingSignalList) afterLoad() {} +func (x *pendingSignalList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *pendingSignalEntry) beforeSave() {} +func (x *pendingSignalEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *pendingSignalEntry) afterLoad() {} +func (x *pendingSignalEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *savedPendingSignal) beforeSave() {} +func (x *savedPendingSignal) save(m state.Map) { + x.beforeSave() + m.Save("si", &x.si) + m.Save("timer", &x.timer) +} + +func (x *savedPendingSignal) afterLoad() {} +func (x *savedPendingSignal) load(m state.Map) { + m.Load("si", &x.si) + m.Load("timer", &x.timer) +} + +func (x *IntervalTimer) beforeSave() {} +func (x *IntervalTimer) save(m state.Map) { + x.beforeSave() + m.Save("timer", &x.timer) + m.Save("target", &x.target) + m.Save("signo", &x.signo) + m.Save("id", &x.id) + m.Save("sigval", &x.sigval) + m.Save("group", &x.group) + m.Save("sigpending", &x.sigpending) + m.Save("sigorphan", &x.sigorphan) + m.Save("overrunCur", &x.overrunCur) + m.Save("overrunLast", &x.overrunLast) +} + +func (x *IntervalTimer) afterLoad() {} +func (x *IntervalTimer) load(m state.Map) { + m.Load("timer", &x.timer) + m.Load("target", &x.target) + m.Load("signo", &x.signo) + m.Load("id", &x.id) + m.Load("sigval", &x.sigval) + m.Load("group", &x.group) + m.Load("sigpending", &x.sigpending) + m.Load("sigorphan", &x.sigorphan) + m.Load("overrunCur", &x.overrunCur) + m.Load("overrunLast", &x.overrunLast) +} + +func (x *processGroupList) beforeSave() {} +func (x *processGroupList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *processGroupList) afterLoad() {} +func (x *processGroupList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *processGroupEntry) beforeSave() {} +func (x *processGroupEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *processGroupEntry) afterLoad() {} +func (x *processGroupEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *ptraceOptions) beforeSave() {} +func (x *ptraceOptions) save(m state.Map) { + x.beforeSave() + m.Save("ExitKill", &x.ExitKill) + m.Save("SysGood", &x.SysGood) + m.Save("TraceClone", &x.TraceClone) + m.Save("TraceExec", &x.TraceExec) + m.Save("TraceExit", &x.TraceExit) + m.Save("TraceFork", &x.TraceFork) + m.Save("TraceSeccomp", &x.TraceSeccomp) + m.Save("TraceVfork", &x.TraceVfork) + m.Save("TraceVforkDone", &x.TraceVforkDone) +} + +func (x *ptraceOptions) afterLoad() {} +func (x *ptraceOptions) load(m state.Map) { + m.Load("ExitKill", &x.ExitKill) + m.Load("SysGood", &x.SysGood) + m.Load("TraceClone", &x.TraceClone) + m.Load("TraceExec", &x.TraceExec) + m.Load("TraceExit", &x.TraceExit) + m.Load("TraceFork", &x.TraceFork) + m.Load("TraceSeccomp", &x.TraceSeccomp) + m.Load("TraceVfork", &x.TraceVfork) + m.Load("TraceVforkDone", &x.TraceVforkDone) +} + +func (x *ptraceStop) beforeSave() {} +func (x *ptraceStop) save(m state.Map) { + x.beforeSave() + m.Save("frozen", &x.frozen) + m.Save("listen", &x.listen) +} + +func (x *ptraceStop) afterLoad() {} +func (x *ptraceStop) load(m state.Map) { + m.Load("frozen", &x.frozen) + m.Load("listen", &x.listen) +} + +func (x *RSEQCriticalRegion) beforeSave() {} +func (x *RSEQCriticalRegion) save(m state.Map) { + x.beforeSave() + m.Save("CriticalSection", &x.CriticalSection) + m.Save("Restart", &x.Restart) +} + +func (x *RSEQCriticalRegion) afterLoad() {} +func (x *RSEQCriticalRegion) load(m state.Map) { + m.Load("CriticalSection", &x.CriticalSection) + m.Load("Restart", &x.Restart) +} + +func (x *sessionList) beforeSave() {} +func (x *sessionList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *sessionList) afterLoad() {} +func (x *sessionList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *sessionEntry) beforeSave() {} +func (x *sessionEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *sessionEntry) afterLoad() {} +func (x *sessionEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *Session) beforeSave() {} +func (x *Session) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) + m.Save("leader", &x.leader) + m.Save("id", &x.id) + m.Save("processGroups", &x.processGroups) + m.Save("sessionEntry", &x.sessionEntry) +} + +func (x *Session) afterLoad() {} +func (x *Session) load(m state.Map) { + m.Load("refs", &x.refs) + m.Load("leader", &x.leader) + m.Load("id", &x.id) + m.Load("processGroups", &x.processGroups) + m.Load("sessionEntry", &x.sessionEntry) +} + +func (x *ProcessGroup) beforeSave() {} +func (x *ProcessGroup) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) + m.Save("originator", &x.originator) + m.Save("id", &x.id) + m.Save("session", &x.session) + m.Save("ancestors", &x.ancestors) + m.Save("processGroupEntry", &x.processGroupEntry) +} + +func (x *ProcessGroup) afterLoad() {} +func (x *ProcessGroup) load(m state.Map) { + m.Load("refs", &x.refs) + m.Load("originator", &x.originator) + m.Load("id", &x.id) + m.Load("session", &x.session) + m.Load("ancestors", &x.ancestors) + m.Load("processGroupEntry", &x.processGroupEntry) +} + +func (x *SignalHandlers) beforeSave() {} +func (x *SignalHandlers) save(m state.Map) { + x.beforeSave() + m.Save("actions", &x.actions) +} + +func (x *SignalHandlers) afterLoad() {} +func (x *SignalHandlers) load(m state.Map) { + m.Load("actions", &x.actions) +} + +func (x *socketList) beforeSave() {} +func (x *socketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *socketList) afterLoad() {} +func (x *socketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *socketEntry) beforeSave() {} +func (x *socketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *socketEntry) afterLoad() {} +func (x *socketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *SyscallTable) beforeSave() {} +func (x *SyscallTable) save(m state.Map) { + x.beforeSave() + m.Save("OS", &x.OS) + m.Save("Arch", &x.Arch) +} + +func (x *SyscallTable) load(m state.Map) { + m.LoadWait("OS", &x.OS) + m.LoadWait("Arch", &x.Arch) + m.AfterLoad(x.afterLoad) +} + +func (x *syslog) beforeSave() {} +func (x *syslog) save(m state.Map) { + x.beforeSave() + m.Save("msg", &x.msg) +} + +func (x *syslog) afterLoad() {} +func (x *syslog) load(m state.Map) { + m.Load("msg", &x.msg) +} + +func (x *Task) beforeSave() {} +func (x *Task) save(m state.Map) { + x.beforeSave() + var ptraceTracer *Task = x.savePtraceTracer() + m.SaveValue("ptraceTracer", ptraceTracer) + var logPrefix string = x.saveLogPrefix() + m.SaveValue("logPrefix", logPrefix) + var syscallFilters []bpf.Program = x.saveSyscallFilters() + m.SaveValue("syscallFilters", syscallFilters) + m.Save("taskNode", &x.taskNode) + m.Save("runState", &x.runState) + m.Save("haveSyscallReturn", &x.haveSyscallReturn) + m.Save("gosched", &x.gosched) + m.Save("yieldCount", &x.yieldCount) + m.Save("pendingSignals", &x.pendingSignals) + m.Save("signalMask", &x.signalMask) + m.Save("realSignalMask", &x.realSignalMask) + m.Save("haveSavedSignalMask", &x.haveSavedSignalMask) + m.Save("savedSignalMask", &x.savedSignalMask) + m.Save("signalStack", &x.signalStack) + m.Save("groupStopPending", &x.groupStopPending) + m.Save("groupStopAcknowledged", &x.groupStopAcknowledged) + m.Save("trapStopPending", &x.trapStopPending) + m.Save("trapNotifyPending", &x.trapNotifyPending) + m.Save("stop", &x.stop) + m.Save("exitStatus", &x.exitStatus) + m.Save("syscallRestartBlock", &x.syscallRestartBlock) + m.Save("k", &x.k) + m.Save("containerID", &x.containerID) + m.Save("tc", &x.tc) + m.Save("fsc", &x.fsc) + m.Save("fds", &x.fds) + m.Save("vforkParent", &x.vforkParent) + m.Save("exitState", &x.exitState) + m.Save("exitTracerNotified", &x.exitTracerNotified) + m.Save("exitTracerAcked", &x.exitTracerAcked) + m.Save("exitParentNotified", &x.exitParentNotified) + m.Save("exitParentAcked", &x.exitParentAcked) + m.Save("ptraceTracees", &x.ptraceTracees) + m.Save("ptraceSeized", &x.ptraceSeized) + m.Save("ptraceOpts", &x.ptraceOpts) + m.Save("ptraceSyscallMode", &x.ptraceSyscallMode) + m.Save("ptraceSinglestep", &x.ptraceSinglestep) + m.Save("ptraceCode", &x.ptraceCode) + m.Save("ptraceSiginfo", &x.ptraceSiginfo) + m.Save("ptraceEventMsg", &x.ptraceEventMsg) + m.Save("ioUsage", &x.ioUsage) + m.Save("creds", &x.creds) + m.Save("utsns", &x.utsns) + m.Save("ipcns", &x.ipcns) + m.Save("abstractSockets", &x.abstractSockets) + m.Save("parentDeathSignal", &x.parentDeathSignal) + m.Save("cleartid", &x.cleartid) + m.Save("allowedCPUMask", &x.allowedCPUMask) + m.Save("cpu", &x.cpu) + m.Save("niceness", &x.niceness) + m.Save("numaPolicy", &x.numaPolicy) + m.Save("numaNodeMask", &x.numaNodeMask) + m.Save("netns", &x.netns) + m.Save("rseqCPUAddr", &x.rseqCPUAddr) + m.Save("rseqCPU", &x.rseqCPU) + m.Save("startTime", &x.startTime) +} + +func (x *Task) load(m state.Map) { + m.Load("taskNode", &x.taskNode) + m.Load("runState", &x.runState) + m.Load("haveSyscallReturn", &x.haveSyscallReturn) + m.Load("gosched", &x.gosched) + m.Load("yieldCount", &x.yieldCount) + m.Load("pendingSignals", &x.pendingSignals) + m.Load("signalMask", &x.signalMask) + m.Load("realSignalMask", &x.realSignalMask) + m.Load("haveSavedSignalMask", &x.haveSavedSignalMask) + m.Load("savedSignalMask", &x.savedSignalMask) + m.Load("signalStack", &x.signalStack) + m.Load("groupStopPending", &x.groupStopPending) + m.Load("groupStopAcknowledged", &x.groupStopAcknowledged) + m.Load("trapStopPending", &x.trapStopPending) + m.Load("trapNotifyPending", &x.trapNotifyPending) + m.Load("stop", &x.stop) + m.Load("exitStatus", &x.exitStatus) + m.Load("syscallRestartBlock", &x.syscallRestartBlock) + m.Load("k", &x.k) + m.Load("containerID", &x.containerID) + m.Load("tc", &x.tc) + m.Load("fsc", &x.fsc) + m.Load("fds", &x.fds) + m.Load("vforkParent", &x.vforkParent) + m.Load("exitState", &x.exitState) + m.Load("exitTracerNotified", &x.exitTracerNotified) + m.Load("exitTracerAcked", &x.exitTracerAcked) + m.Load("exitParentNotified", &x.exitParentNotified) + m.Load("exitParentAcked", &x.exitParentAcked) + m.Load("ptraceTracees", &x.ptraceTracees) + m.Load("ptraceSeized", &x.ptraceSeized) + m.Load("ptraceOpts", &x.ptraceOpts) + m.Load("ptraceSyscallMode", &x.ptraceSyscallMode) + m.Load("ptraceSinglestep", &x.ptraceSinglestep) + m.Load("ptraceCode", &x.ptraceCode) + m.Load("ptraceSiginfo", &x.ptraceSiginfo) + m.Load("ptraceEventMsg", &x.ptraceEventMsg) + m.Load("ioUsage", &x.ioUsage) + m.Load("creds", &x.creds) + m.Load("utsns", &x.utsns) + m.Load("ipcns", &x.ipcns) + m.Load("abstractSockets", &x.abstractSockets) + m.Load("parentDeathSignal", &x.parentDeathSignal) + m.Load("cleartid", &x.cleartid) + m.Load("allowedCPUMask", &x.allowedCPUMask) + m.Load("cpu", &x.cpu) + m.Load("niceness", &x.niceness) + m.Load("numaPolicy", &x.numaPolicy) + m.Load("numaNodeMask", &x.numaNodeMask) + m.Load("netns", &x.netns) + m.Load("rseqCPUAddr", &x.rseqCPUAddr) + m.Load("rseqCPU", &x.rseqCPU) + m.Load("startTime", &x.startTime) + m.LoadValue("ptraceTracer", new(*Task), func(y interface{}) { x.loadPtraceTracer(y.(*Task)) }) + m.LoadValue("logPrefix", new(string), func(y interface{}) { x.loadLogPrefix(y.(string)) }) + m.LoadValue("syscallFilters", new([]bpf.Program), func(y interface{}) { x.loadSyscallFilters(y.([]bpf.Program)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *runSyscallAfterPtraceEventClone) beforeSave() {} +func (x *runSyscallAfterPtraceEventClone) save(m state.Map) { + x.beforeSave() + m.Save("vforkChild", &x.vforkChild) + m.Save("vforkChildTID", &x.vforkChildTID) +} + +func (x *runSyscallAfterPtraceEventClone) afterLoad() {} +func (x *runSyscallAfterPtraceEventClone) load(m state.Map) { + m.Load("vforkChild", &x.vforkChild) + m.Load("vforkChildTID", &x.vforkChildTID) +} + +func (x *runSyscallAfterVforkStop) beforeSave() {} +func (x *runSyscallAfterVforkStop) save(m state.Map) { + x.beforeSave() + m.Save("childTID", &x.childTID) +} + +func (x *runSyscallAfterVforkStop) afterLoad() {} +func (x *runSyscallAfterVforkStop) load(m state.Map) { + m.Load("childTID", &x.childTID) +} + +func (x *vforkStop) beforeSave() {} +func (x *vforkStop) save(m state.Map) { + x.beforeSave() +} + +func (x *vforkStop) afterLoad() {} +func (x *vforkStop) load(m state.Map) { +} + +func (x *TaskContext) beforeSave() {} +func (x *TaskContext) save(m state.Map) { + x.beforeSave() + m.Save("Name", &x.Name) + m.Save("Arch", &x.Arch) + m.Save("MemoryManager", &x.MemoryManager) + m.Save("fu", &x.fu) + m.Save("st", &x.st) +} + +func (x *TaskContext) afterLoad() {} +func (x *TaskContext) load(m state.Map) { + m.Load("Name", &x.Name) + m.Load("Arch", &x.Arch) + m.Load("MemoryManager", &x.MemoryManager) + m.Load("fu", &x.fu) + m.Load("st", &x.st) +} + +func (x *execStop) beforeSave() {} +func (x *execStop) save(m state.Map) { + x.beforeSave() +} + +func (x *execStop) afterLoad() {} +func (x *execStop) load(m state.Map) { +} + +func (x *runSyscallAfterExecStop) beforeSave() {} +func (x *runSyscallAfterExecStop) save(m state.Map) { + x.beforeSave() + m.Save("tc", &x.tc) +} + +func (x *runSyscallAfterExecStop) afterLoad() {} +func (x *runSyscallAfterExecStop) load(m state.Map) { + m.Load("tc", &x.tc) +} + +func (x *ExitStatus) beforeSave() {} +func (x *ExitStatus) save(m state.Map) { + x.beforeSave() + m.Save("Code", &x.Code) + m.Save("Signo", &x.Signo) +} + +func (x *ExitStatus) afterLoad() {} +func (x *ExitStatus) load(m state.Map) { + m.Load("Code", &x.Code) + m.Load("Signo", &x.Signo) +} + +func (x *runExit) beforeSave() {} +func (x *runExit) save(m state.Map) { + x.beforeSave() +} + +func (x *runExit) afterLoad() {} +func (x *runExit) load(m state.Map) { +} + +func (x *runExitMain) beforeSave() {} +func (x *runExitMain) save(m state.Map) { + x.beforeSave() +} + +func (x *runExitMain) afterLoad() {} +func (x *runExitMain) load(m state.Map) { +} + +func (x *runExitNotify) beforeSave() {} +func (x *runExitNotify) save(m state.Map) { + x.beforeSave() +} + +func (x *runExitNotify) afterLoad() {} +func (x *runExitNotify) load(m state.Map) { +} + +func (x *taskList) beforeSave() {} +func (x *taskList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *taskList) afterLoad() {} +func (x *taskList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *taskEntry) beforeSave() {} +func (x *taskEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *taskEntry) afterLoad() {} +func (x *taskEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *runApp) beforeSave() {} +func (x *runApp) save(m state.Map) { + x.beforeSave() +} + +func (x *runApp) afterLoad() {} +func (x *runApp) load(m state.Map) { +} + +func (x *TaskGoroutineSchedInfo) beforeSave() {} +func (x *TaskGoroutineSchedInfo) save(m state.Map) { + x.beforeSave() + m.Save("Timestamp", &x.Timestamp) + m.Save("State", &x.State) + m.Save("UserTicks", &x.UserTicks) + m.Save("SysTicks", &x.SysTicks) +} + +func (x *TaskGoroutineSchedInfo) afterLoad() {} +func (x *TaskGoroutineSchedInfo) load(m state.Map) { + m.Load("Timestamp", &x.Timestamp) + m.Load("State", &x.State) + m.Load("UserTicks", &x.UserTicks) + m.Load("SysTicks", &x.SysTicks) +} + +func (x *taskClock) beforeSave() {} +func (x *taskClock) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("includeSys", &x.includeSys) +} + +func (x *taskClock) afterLoad() {} +func (x *taskClock) load(m state.Map) { + m.Load("t", &x.t) + m.Load("includeSys", &x.includeSys) +} + +func (x *tgClock) beforeSave() {} +func (x *tgClock) save(m state.Map) { + x.beforeSave() + m.Save("tg", &x.tg) + m.Save("includeSys", &x.includeSys) +} + +func (x *tgClock) afterLoad() {} +func (x *tgClock) load(m state.Map) { + m.Load("tg", &x.tg) + m.Load("includeSys", &x.includeSys) +} + +func (x *groupStop) beforeSave() {} +func (x *groupStop) save(m state.Map) { + x.beforeSave() +} + +func (x *groupStop) afterLoad() {} +func (x *groupStop) load(m state.Map) { +} + +func (x *runInterrupt) beforeSave() {} +func (x *runInterrupt) save(m state.Map) { + x.beforeSave() +} + +func (x *runInterrupt) afterLoad() {} +func (x *runInterrupt) load(m state.Map) { +} + +func (x *runInterruptAfterSignalDeliveryStop) beforeSave() {} +func (x *runInterruptAfterSignalDeliveryStop) save(m state.Map) { + x.beforeSave() +} + +func (x *runInterruptAfterSignalDeliveryStop) afterLoad() {} +func (x *runInterruptAfterSignalDeliveryStop) load(m state.Map) { +} + +func (x *runSyscallAfterSyscallEnterStop) beforeSave() {} +func (x *runSyscallAfterSyscallEnterStop) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallAfterSyscallEnterStop) afterLoad() {} +func (x *runSyscallAfterSyscallEnterStop) load(m state.Map) { +} + +func (x *runSyscallAfterSysemuStop) beforeSave() {} +func (x *runSyscallAfterSysemuStop) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallAfterSysemuStop) afterLoad() {} +func (x *runSyscallAfterSysemuStop) load(m state.Map) { +} + +func (x *runSyscallReinvoke) beforeSave() {} +func (x *runSyscallReinvoke) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallReinvoke) afterLoad() {} +func (x *runSyscallReinvoke) load(m state.Map) { +} + +func (x *runSyscallExit) beforeSave() {} +func (x *runSyscallExit) save(m state.Map) { + x.beforeSave() +} + +func (x *runSyscallExit) afterLoad() {} +func (x *runSyscallExit) load(m state.Map) { +} + +func (x *ThreadGroup) beforeSave() {} +func (x *ThreadGroup) save(m state.Map) { + x.beforeSave() + var rscr *RSEQCriticalRegion = x.saveRscr() + m.SaveValue("rscr", rscr) + m.Save("threadGroupNode", &x.threadGroupNode) + m.Save("signalHandlers", &x.signalHandlers) + m.Save("pendingSignals", &x.pendingSignals) + m.Save("groupStopDequeued", &x.groupStopDequeued) + m.Save("groupStopSignal", &x.groupStopSignal) + m.Save("groupStopPendingCount", &x.groupStopPendingCount) + m.Save("groupStopComplete", &x.groupStopComplete) + m.Save("groupStopWaitable", &x.groupStopWaitable) + m.Save("groupContNotify", &x.groupContNotify) + m.Save("groupContInterrupted", &x.groupContInterrupted) + m.Save("groupContWaitable", &x.groupContWaitable) + m.Save("exiting", &x.exiting) + m.Save("exitStatus", &x.exitStatus) + m.Save("terminationSignal", &x.terminationSignal) + m.Save("itimerRealTimer", &x.itimerRealTimer) + m.Save("itimerVirtSetting", &x.itimerVirtSetting) + m.Save("itimerProfSetting", &x.itimerProfSetting) + m.Save("rlimitCPUSoftSetting", &x.rlimitCPUSoftSetting) + m.Save("cpuTimersEnabled", &x.cpuTimersEnabled) + m.Save("timers", &x.timers) + m.Save("nextTimerID", &x.nextTimerID) + m.Save("exitedCPUStats", &x.exitedCPUStats) + m.Save("childCPUStats", &x.childCPUStats) + m.Save("ioUsage", &x.ioUsage) + m.Save("maxRSS", &x.maxRSS) + m.Save("childMaxRSS", &x.childMaxRSS) + m.Save("limits", &x.limits) + m.Save("processGroup", &x.processGroup) + m.Save("execed", &x.execed) +} + +func (x *ThreadGroup) afterLoad() {} +func (x *ThreadGroup) load(m state.Map) { + m.Load("threadGroupNode", &x.threadGroupNode) + m.Load("signalHandlers", &x.signalHandlers) + m.Load("pendingSignals", &x.pendingSignals) + m.Load("groupStopDequeued", &x.groupStopDequeued) + m.Load("groupStopSignal", &x.groupStopSignal) + m.Load("groupStopPendingCount", &x.groupStopPendingCount) + m.Load("groupStopComplete", &x.groupStopComplete) + m.Load("groupStopWaitable", &x.groupStopWaitable) + m.Load("groupContNotify", &x.groupContNotify) + m.Load("groupContInterrupted", &x.groupContInterrupted) + m.Load("groupContWaitable", &x.groupContWaitable) + m.Load("exiting", &x.exiting) + m.Load("exitStatus", &x.exitStatus) + m.Load("terminationSignal", &x.terminationSignal) + m.Load("itimerRealTimer", &x.itimerRealTimer) + m.Load("itimerVirtSetting", &x.itimerVirtSetting) + m.Load("itimerProfSetting", &x.itimerProfSetting) + m.Load("rlimitCPUSoftSetting", &x.rlimitCPUSoftSetting) + m.Load("cpuTimersEnabled", &x.cpuTimersEnabled) + m.Load("timers", &x.timers) + m.Load("nextTimerID", &x.nextTimerID) + m.Load("exitedCPUStats", &x.exitedCPUStats) + m.Load("childCPUStats", &x.childCPUStats) + m.Load("ioUsage", &x.ioUsage) + m.Load("maxRSS", &x.maxRSS) + m.Load("childMaxRSS", &x.childMaxRSS) + m.Load("limits", &x.limits) + m.Load("processGroup", &x.processGroup) + m.Load("execed", &x.execed) + m.LoadValue("rscr", new(*RSEQCriticalRegion), func(y interface{}) { x.loadRscr(y.(*RSEQCriticalRegion)) }) +} + +func (x *itimerRealListener) beforeSave() {} +func (x *itimerRealListener) save(m state.Map) { + x.beforeSave() + m.Save("tg", &x.tg) +} + +func (x *itimerRealListener) afterLoad() {} +func (x *itimerRealListener) load(m state.Map) { + m.Load("tg", &x.tg) +} + +func (x *TaskSet) beforeSave() {} +func (x *TaskSet) save(m state.Map) { + x.beforeSave() + m.Save("Root", &x.Root) + m.Save("sessions", &x.sessions) +} + +func (x *TaskSet) afterLoad() {} +func (x *TaskSet) load(m state.Map) { + m.Load("Root", &x.Root) + m.Load("sessions", &x.sessions) +} + +func (x *PIDNamespace) beforeSave() {} +func (x *PIDNamespace) save(m state.Map) { + x.beforeSave() + m.Save("owner", &x.owner) + m.Save("parent", &x.parent) + m.Save("userns", &x.userns) + m.Save("last", &x.last) + m.Save("tasks", &x.tasks) + m.Save("tids", &x.tids) + m.Save("tgids", &x.tgids) + m.Save("sessions", &x.sessions) + m.Save("sids", &x.sids) + m.Save("processGroups", &x.processGroups) + m.Save("pgids", &x.pgids) + m.Save("exiting", &x.exiting) +} + +func (x *PIDNamespace) afterLoad() {} +func (x *PIDNamespace) load(m state.Map) { + m.Load("owner", &x.owner) + m.Load("parent", &x.parent) + m.Load("userns", &x.userns) + m.Load("last", &x.last) + m.Load("tasks", &x.tasks) + m.Load("tids", &x.tids) + m.Load("tgids", &x.tgids) + m.Load("sessions", &x.sessions) + m.Load("sids", &x.sids) + m.Load("processGroups", &x.processGroups) + m.Load("pgids", &x.pgids) + m.Load("exiting", &x.exiting) +} + +func (x *threadGroupNode) beforeSave() {} +func (x *threadGroupNode) save(m state.Map) { + x.beforeSave() + m.Save("pidns", &x.pidns) + m.Save("leader", &x.leader) + m.Save("execing", &x.execing) + m.Save("tasks", &x.tasks) + m.Save("tasksCount", &x.tasksCount) + m.Save("liveTasks", &x.liveTasks) + m.Save("activeTasks", &x.activeTasks) +} + +func (x *threadGroupNode) afterLoad() {} +func (x *threadGroupNode) load(m state.Map) { + m.Load("pidns", &x.pidns) + m.Load("leader", &x.leader) + m.Load("execing", &x.execing) + m.Load("tasks", &x.tasks) + m.Load("tasksCount", &x.tasksCount) + m.Load("liveTasks", &x.liveTasks) + m.Load("activeTasks", &x.activeTasks) +} + +func (x *taskNode) beforeSave() {} +func (x *taskNode) save(m state.Map) { + x.beforeSave() + m.Save("tg", &x.tg) + m.Save("taskEntry", &x.taskEntry) + m.Save("parent", &x.parent) + m.Save("children", &x.children) + m.Save("childPIDNamespace", &x.childPIDNamespace) +} + +func (x *taskNode) afterLoad() {} +func (x *taskNode) load(m state.Map) { + m.LoadWait("tg", &x.tg) + m.Load("taskEntry", &x.taskEntry) + m.Load("parent", &x.parent) + m.Load("children", &x.children) + m.Load("childPIDNamespace", &x.childPIDNamespace) +} + +func (x *Timekeeper) save(m state.Map) { + x.beforeSave() + m.Save("bootTime", &x.bootTime) + m.Save("saveMonotonic", &x.saveMonotonic) + m.Save("saveRealtime", &x.saveRealtime) + m.Save("params", &x.params) +} + +func (x *Timekeeper) load(m state.Map) { + m.Load("bootTime", &x.bootTime) + m.Load("saveMonotonic", &x.saveMonotonic) + m.Load("saveRealtime", &x.saveRealtime) + m.Load("params", &x.params) + m.AfterLoad(x.afterLoad) +} + +func (x *timekeeperClock) beforeSave() {} +func (x *timekeeperClock) save(m state.Map) { + x.beforeSave() + m.Save("tk", &x.tk) + m.Save("c", &x.c) +} + +func (x *timekeeperClock) afterLoad() {} +func (x *timekeeperClock) load(m state.Map) { + m.Load("tk", &x.tk) + m.Load("c", &x.c) +} + +func (x *UTSNamespace) beforeSave() {} +func (x *UTSNamespace) save(m state.Map) { + x.beforeSave() + m.Save("hostName", &x.hostName) + m.Save("domainName", &x.domainName) + m.Save("userns", &x.userns) +} + +func (x *UTSNamespace) afterLoad() {} +func (x *UTSNamespace) load(m state.Map) { + m.Load("hostName", &x.hostName) + m.Load("domainName", &x.domainName) + m.Load("userns", &x.userns) +} + +func (x *VDSOParamPage) beforeSave() {} +func (x *VDSOParamPage) save(m state.Map) { + x.beforeSave() + m.Save("mfp", &x.mfp) + m.Save("fr", &x.fr) + m.Save("seq", &x.seq) +} + +func (x *VDSOParamPage) afterLoad() {} +func (x *VDSOParamPage) load(m state.Map) { + m.Load("mfp", &x.mfp) + m.Load("fr", &x.fr) + m.Load("seq", &x.seq) +} + +func init() { + state.Register("kernel.abstractEndpoint", (*abstractEndpoint)(nil), state.Fns{Save: (*abstractEndpoint).save, Load: (*abstractEndpoint).load}) + state.Register("kernel.AbstractSocketNamespace", (*AbstractSocketNamespace)(nil), state.Fns{Save: (*AbstractSocketNamespace).save, Load: (*AbstractSocketNamespace).load}) + state.Register("kernel.FDFlags", (*FDFlags)(nil), state.Fns{Save: (*FDFlags).save, Load: (*FDFlags).load}) + state.Register("kernel.descriptor", (*descriptor)(nil), state.Fns{Save: (*descriptor).save, Load: (*descriptor).load}) + state.Register("kernel.FDMap", (*FDMap)(nil), state.Fns{Save: (*FDMap).save, Load: (*FDMap).load}) + state.Register("kernel.FSContext", (*FSContext)(nil), state.Fns{Save: (*FSContext).save, Load: (*FSContext).load}) + state.Register("kernel.IPCNamespace", (*IPCNamespace)(nil), state.Fns{Save: (*IPCNamespace).save, Load: (*IPCNamespace).load}) + state.Register("kernel.Kernel", (*Kernel)(nil), state.Fns{Save: (*Kernel).save, Load: (*Kernel).load}) + state.Register("kernel.SocketEntry", (*SocketEntry)(nil), state.Fns{Save: (*SocketEntry).save, Load: (*SocketEntry).load}) + state.Register("kernel.pendingSignals", (*pendingSignals)(nil), state.Fns{Save: (*pendingSignals).save, Load: (*pendingSignals).load}) + state.Register("kernel.pendingSignalQueue", (*pendingSignalQueue)(nil), state.Fns{Save: (*pendingSignalQueue).save, Load: (*pendingSignalQueue).load}) + state.Register("kernel.pendingSignal", (*pendingSignal)(nil), state.Fns{Save: (*pendingSignal).save, Load: (*pendingSignal).load}) + state.Register("kernel.pendingSignalList", (*pendingSignalList)(nil), state.Fns{Save: (*pendingSignalList).save, Load: (*pendingSignalList).load}) + state.Register("kernel.pendingSignalEntry", (*pendingSignalEntry)(nil), state.Fns{Save: (*pendingSignalEntry).save, Load: (*pendingSignalEntry).load}) + state.Register("kernel.savedPendingSignal", (*savedPendingSignal)(nil), state.Fns{Save: (*savedPendingSignal).save, Load: (*savedPendingSignal).load}) + state.Register("kernel.IntervalTimer", (*IntervalTimer)(nil), state.Fns{Save: (*IntervalTimer).save, Load: (*IntervalTimer).load}) + state.Register("kernel.processGroupList", (*processGroupList)(nil), state.Fns{Save: (*processGroupList).save, Load: (*processGroupList).load}) + state.Register("kernel.processGroupEntry", (*processGroupEntry)(nil), state.Fns{Save: (*processGroupEntry).save, Load: (*processGroupEntry).load}) + state.Register("kernel.ptraceOptions", (*ptraceOptions)(nil), state.Fns{Save: (*ptraceOptions).save, Load: (*ptraceOptions).load}) + state.Register("kernel.ptraceStop", (*ptraceStop)(nil), state.Fns{Save: (*ptraceStop).save, Load: (*ptraceStop).load}) + state.Register("kernel.RSEQCriticalRegion", (*RSEQCriticalRegion)(nil), state.Fns{Save: (*RSEQCriticalRegion).save, Load: (*RSEQCriticalRegion).load}) + state.Register("kernel.sessionList", (*sessionList)(nil), state.Fns{Save: (*sessionList).save, Load: (*sessionList).load}) + state.Register("kernel.sessionEntry", (*sessionEntry)(nil), state.Fns{Save: (*sessionEntry).save, Load: (*sessionEntry).load}) + state.Register("kernel.Session", (*Session)(nil), state.Fns{Save: (*Session).save, Load: (*Session).load}) + state.Register("kernel.ProcessGroup", (*ProcessGroup)(nil), state.Fns{Save: (*ProcessGroup).save, Load: (*ProcessGroup).load}) + state.Register("kernel.SignalHandlers", (*SignalHandlers)(nil), state.Fns{Save: (*SignalHandlers).save, Load: (*SignalHandlers).load}) + state.Register("kernel.socketList", (*socketList)(nil), state.Fns{Save: (*socketList).save, Load: (*socketList).load}) + state.Register("kernel.socketEntry", (*socketEntry)(nil), state.Fns{Save: (*socketEntry).save, Load: (*socketEntry).load}) + state.Register("kernel.SyscallTable", (*SyscallTable)(nil), state.Fns{Save: (*SyscallTable).save, Load: (*SyscallTable).load}) + state.Register("kernel.syslog", (*syslog)(nil), state.Fns{Save: (*syslog).save, Load: (*syslog).load}) + state.Register("kernel.Task", (*Task)(nil), state.Fns{Save: (*Task).save, Load: (*Task).load}) + state.Register("kernel.runSyscallAfterPtraceEventClone", (*runSyscallAfterPtraceEventClone)(nil), state.Fns{Save: (*runSyscallAfterPtraceEventClone).save, Load: (*runSyscallAfterPtraceEventClone).load}) + state.Register("kernel.runSyscallAfterVforkStop", (*runSyscallAfterVforkStop)(nil), state.Fns{Save: (*runSyscallAfterVforkStop).save, Load: (*runSyscallAfterVforkStop).load}) + state.Register("kernel.vforkStop", (*vforkStop)(nil), state.Fns{Save: (*vforkStop).save, Load: (*vforkStop).load}) + state.Register("kernel.TaskContext", (*TaskContext)(nil), state.Fns{Save: (*TaskContext).save, Load: (*TaskContext).load}) + state.Register("kernel.execStop", (*execStop)(nil), state.Fns{Save: (*execStop).save, Load: (*execStop).load}) + state.Register("kernel.runSyscallAfterExecStop", (*runSyscallAfterExecStop)(nil), state.Fns{Save: (*runSyscallAfterExecStop).save, Load: (*runSyscallAfterExecStop).load}) + state.Register("kernel.ExitStatus", (*ExitStatus)(nil), state.Fns{Save: (*ExitStatus).save, Load: (*ExitStatus).load}) + state.Register("kernel.runExit", (*runExit)(nil), state.Fns{Save: (*runExit).save, Load: (*runExit).load}) + state.Register("kernel.runExitMain", (*runExitMain)(nil), state.Fns{Save: (*runExitMain).save, Load: (*runExitMain).load}) + state.Register("kernel.runExitNotify", (*runExitNotify)(nil), state.Fns{Save: (*runExitNotify).save, Load: (*runExitNotify).load}) + state.Register("kernel.taskList", (*taskList)(nil), state.Fns{Save: (*taskList).save, Load: (*taskList).load}) + state.Register("kernel.taskEntry", (*taskEntry)(nil), state.Fns{Save: (*taskEntry).save, Load: (*taskEntry).load}) + state.Register("kernel.runApp", (*runApp)(nil), state.Fns{Save: (*runApp).save, Load: (*runApp).load}) + state.Register("kernel.TaskGoroutineSchedInfo", (*TaskGoroutineSchedInfo)(nil), state.Fns{Save: (*TaskGoroutineSchedInfo).save, Load: (*TaskGoroutineSchedInfo).load}) + state.Register("kernel.taskClock", (*taskClock)(nil), state.Fns{Save: (*taskClock).save, Load: (*taskClock).load}) + state.Register("kernel.tgClock", (*tgClock)(nil), state.Fns{Save: (*tgClock).save, Load: (*tgClock).load}) + state.Register("kernel.groupStop", (*groupStop)(nil), state.Fns{Save: (*groupStop).save, Load: (*groupStop).load}) + state.Register("kernel.runInterrupt", (*runInterrupt)(nil), state.Fns{Save: (*runInterrupt).save, Load: (*runInterrupt).load}) + state.Register("kernel.runInterruptAfterSignalDeliveryStop", (*runInterruptAfterSignalDeliveryStop)(nil), state.Fns{Save: (*runInterruptAfterSignalDeliveryStop).save, Load: (*runInterruptAfterSignalDeliveryStop).load}) + state.Register("kernel.runSyscallAfterSyscallEnterStop", (*runSyscallAfterSyscallEnterStop)(nil), state.Fns{Save: (*runSyscallAfterSyscallEnterStop).save, Load: (*runSyscallAfterSyscallEnterStop).load}) + state.Register("kernel.runSyscallAfterSysemuStop", (*runSyscallAfterSysemuStop)(nil), state.Fns{Save: (*runSyscallAfterSysemuStop).save, Load: (*runSyscallAfterSysemuStop).load}) + state.Register("kernel.runSyscallReinvoke", (*runSyscallReinvoke)(nil), state.Fns{Save: (*runSyscallReinvoke).save, Load: (*runSyscallReinvoke).load}) + state.Register("kernel.runSyscallExit", (*runSyscallExit)(nil), state.Fns{Save: (*runSyscallExit).save, Load: (*runSyscallExit).load}) + state.Register("kernel.ThreadGroup", (*ThreadGroup)(nil), state.Fns{Save: (*ThreadGroup).save, Load: (*ThreadGroup).load}) + state.Register("kernel.itimerRealListener", (*itimerRealListener)(nil), state.Fns{Save: (*itimerRealListener).save, Load: (*itimerRealListener).load}) + state.Register("kernel.TaskSet", (*TaskSet)(nil), state.Fns{Save: (*TaskSet).save, Load: (*TaskSet).load}) + state.Register("kernel.PIDNamespace", (*PIDNamespace)(nil), state.Fns{Save: (*PIDNamespace).save, Load: (*PIDNamespace).load}) + state.Register("kernel.threadGroupNode", (*threadGroupNode)(nil), state.Fns{Save: (*threadGroupNode).save, Load: (*threadGroupNode).load}) + state.Register("kernel.taskNode", (*taskNode)(nil), state.Fns{Save: (*taskNode).save, Load: (*taskNode).load}) + state.Register("kernel.Timekeeper", (*Timekeeper)(nil), state.Fns{Save: (*Timekeeper).save, Load: (*Timekeeper).load}) + state.Register("kernel.timekeeperClock", (*timekeeperClock)(nil), state.Fns{Save: (*timekeeperClock).save, Load: (*timekeeperClock).load}) + state.Register("kernel.UTSNamespace", (*UTSNamespace)(nil), state.Fns{Save: (*UTSNamespace).save, Load: (*UTSNamespace).load}) + state.Register("kernel.VDSOParamPage", (*VDSOParamPage)(nil), state.Fns{Save: (*VDSOParamPage).save, Load: (*VDSOParamPage).load}) +} diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD deleted file mode 100644 index ebcfaa619..000000000 --- a/pkg/sentry/kernel/memevent/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "memevent", - srcs = ["memory_events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent", - visibility = ["//:sandbox"], - deps = [ - ":memory_events_go_proto", - "//pkg/eventchannel", - "//pkg/log", - "//pkg/metric", - "//pkg/sentry/kernel", - "//pkg/sentry/usage", - ], -) - -proto_library( - name = "memory_events_proto", - srcs = ["memory_events.proto"], - visibility = ["//visibility:public"], -) - -go_proto_library( - name = "memory_events_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", - proto = ":memory_events_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go deleted file mode 100644 index b0d98e7f0..000000000 --- a/pkg/sentry/kernel/memevent/memory_events.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memevent implements the memory usage events controller, which -// periodically emits events via the eventchannel. -package memevent - -import ( - "sync" - "time" - - "gvisor.dev/gvisor/pkg/eventchannel" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/metric" - "gvisor.dev/gvisor/pkg/sentry/kernel" - pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto" - "gvisor.dev/gvisor/pkg/sentry/usage" -) - -var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.") -var totalEvents = metric.MustCreateNewUint64Metric("/memory_events/events", false /*sync*/, "Total number of memory events emitted.") - -// MemoryEvents describes the configuration for the global memory event emitter. -type MemoryEvents struct { - k *kernel.Kernel - - // The period is how often to emit an event. The memory events goroutine - // will ensure a minimum of one event is emitted per this period, regardless - // how of much memory usage has changed. - period time.Duration - - // Writing to this channel indicates the memory goroutine should stop. - stop chan struct{} - - // done is used to signal when the memory event goroutine has exited. - done sync.WaitGroup -} - -// New creates a new MemoryEvents. -func New(k *kernel.Kernel, period time.Duration) *MemoryEvents { - return &MemoryEvents{ - k: k, - period: period, - stop: make(chan struct{}), - } -} - -// Stop stops the memory usage events emitter goroutine. Stop must not be called -// concurrently with Start and may only be called once. -func (m *MemoryEvents) Stop() { - close(m.stop) - m.done.Wait() -} - -// Start starts the memory usage events emitter goroutine. Start must not be -// called concurrently with Stop and may only be called once. -func (m *MemoryEvents) Start() { - if m.period == 0 { - return - } - m.done.Add(1) - go m.run() // S/R-SAFE: doesn't interact with saved state. -} - -func (m *MemoryEvents) run() { - defer m.done.Done() - - // Emit the first event immediately on startup. - totalTicks.Increment() - m.emit() - - ticker := time.NewTicker(m.period) - defer ticker.Stop() - - for { - select { - case <-m.stop: - return - case <-ticker.C: - totalTicks.Increment() - m.emit() - } - } -} - -func (m *MemoryEvents) emit() { - totalPlatform, err := m.k.MemoryFile().TotalUsage() - if err != nil { - log.Warningf("Failed to fetch memory usage for memory events: %v", err) - return - } - snapshot, _ := usage.MemoryAccounting.Copy() - total := totalPlatform + snapshot.Mapped - - totalEvents.Increment() - eventchannel.Emit(&pb.MemoryUsageEvent{ - Mapped: snapshot.Mapped, - Total: total, - }) -} diff --git a/pkg/sentry/kernel/memevent/memory_events.proto b/pkg/sentry/kernel/memevent/memory_events.proto deleted file mode 100644 index bf8029ff5..000000000 --- a/pkg/sentry/kernel/memevent/memory_events.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -// MemoryUsageEvent describes the memory usage of the sandbox at a single -// instant in time. These messages are emitted periodically on the eventchannel. -message MemoryUsageEvent { - // The total memory usage of the sandboxed application in bytes, calculated - // using the 'fast' method. - uint64 total = 1; - - // Memory used to back memory-mapped regions for files in the application, in - // bytes. This corresponds to the usage.MemoryKind.Mapped memory type. - uint64 mapped = 2; -} diff --git a/pkg/sentry/kernel/pending_signals_list.go b/pkg/sentry/kernel/pending_signals_list.go new file mode 100755 index 000000000..a3499371a --- /dev/null +++ b/pkg/sentry/kernel/pending_signals_list.go @@ -0,0 +1,173 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type pendingSignalElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (pendingSignalElementMapper) linkerFor(elem *pendingSignal) *pendingSignal { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type pendingSignalList struct { + head *pendingSignal + tail *pendingSignal +} + +// Reset resets list l to the empty state. +func (l *pendingSignalList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *pendingSignalList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *pendingSignalList) Front() *pendingSignal { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *pendingSignalList) Back() *pendingSignal { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *pendingSignalList) PushFront(e *pendingSignal) { + pendingSignalElementMapper{}.linkerFor(e).SetNext(l.head) + pendingSignalElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + pendingSignalElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *pendingSignalList) PushBack(e *pendingSignal) { + pendingSignalElementMapper{}.linkerFor(e).SetNext(nil) + pendingSignalElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *pendingSignalList) PushBackList(m *pendingSignalList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + pendingSignalElementMapper{}.linkerFor(l.tail).SetNext(m.head) + pendingSignalElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *pendingSignalList) InsertAfter(b, e *pendingSignal) { + a := pendingSignalElementMapper{}.linkerFor(b).Next() + pendingSignalElementMapper{}.linkerFor(e).SetNext(a) + pendingSignalElementMapper{}.linkerFor(e).SetPrev(b) + pendingSignalElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + pendingSignalElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *pendingSignalList) InsertBefore(a, e *pendingSignal) { + b := pendingSignalElementMapper{}.linkerFor(a).Prev() + pendingSignalElementMapper{}.linkerFor(e).SetNext(a) + pendingSignalElementMapper{}.linkerFor(e).SetPrev(b) + pendingSignalElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + pendingSignalElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *pendingSignalList) Remove(e *pendingSignal) { + prev := pendingSignalElementMapper{}.linkerFor(e).Prev() + next := pendingSignalElementMapper{}.linkerFor(e).Next() + + if prev != nil { + pendingSignalElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + pendingSignalElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type pendingSignalEntry struct { + next *pendingSignal + prev *pendingSignal +} + +// Next returns the entry that follows e in the list. +func (e *pendingSignalEntry) Next() *pendingSignal { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *pendingSignalEntry) Prev() *pendingSignal { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *pendingSignalEntry) SetNext(elem *pendingSignal) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *pendingSignalEntry) SetPrev(elem *pendingSignal) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD deleted file mode 100644 index 4d15cca85..000000000 --- a/pkg/sentry/kernel/pipe/BUILD +++ /dev/null @@ -1,64 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "buffer_list", - out = "buffer_list.go", - package = "pipe", - prefix = "buffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*buffer", - "Linker": "*buffer", - }, -) - -go_library( - name = "pipe", - srcs = [ - "buffer.go", - "buffer_list.go", - "device.go", - "node.go", - "pipe.go", - "reader.go", - "reader_writer.go", - "writer.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/safemem", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) - -go_test( - name = "pipe_test", - size = "small", - srcs = [ - "buffer_test.go", - "node_test.go", - "pipe_test.go", - ], - embed = [":pipe"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/fs", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/pipe/buffer_list.go b/pkg/sentry/kernel/pipe/buffer_list.go new file mode 100755 index 000000000..42ec78788 --- /dev/null +++ b/pkg/sentry/kernel/pipe/buffer_list.go @@ -0,0 +1,173 @@ +package pipe + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type bufferElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (bufferElementMapper) linkerFor(elem *buffer) *buffer { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type bufferList struct { + head *buffer + tail *buffer +} + +// Reset resets list l to the empty state. +func (l *bufferList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *bufferList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *bufferList) Front() *buffer { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *bufferList) Back() *buffer { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *bufferList) PushFront(e *buffer) { + bufferElementMapper{}.linkerFor(e).SetNext(l.head) + bufferElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + bufferElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *bufferList) PushBack(e *buffer) { + bufferElementMapper{}.linkerFor(e).SetNext(nil) + bufferElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + bufferElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *bufferList) PushBackList(m *bufferList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + bufferElementMapper{}.linkerFor(l.tail).SetNext(m.head) + bufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *bufferList) InsertAfter(b, e *buffer) { + a := bufferElementMapper{}.linkerFor(b).Next() + bufferElementMapper{}.linkerFor(e).SetNext(a) + bufferElementMapper{}.linkerFor(e).SetPrev(b) + bufferElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + bufferElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *bufferList) InsertBefore(a, e *buffer) { + b := bufferElementMapper{}.linkerFor(a).Prev() + bufferElementMapper{}.linkerFor(e).SetNext(a) + bufferElementMapper{}.linkerFor(e).SetPrev(b) + bufferElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + bufferElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *bufferList) Remove(e *buffer) { + prev := bufferElementMapper{}.linkerFor(e).Prev() + next := bufferElementMapper{}.linkerFor(e).Next() + + if prev != nil { + bufferElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + bufferElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type bufferEntry struct { + next *buffer + prev *buffer +} + +// Next returns the entry that follows e in the list. +func (e *bufferEntry) Next() *buffer { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *bufferEntry) Prev() *buffer { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *bufferEntry) SetNext(elem *buffer) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *bufferEntry) SetPrev(elem *buffer) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/pkg/sentry/kernel/pipe/buffer_test.go deleted file mode 100644 index ee1b90115..000000000 --- a/pkg/sentry/kernel/pipe/buffer_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "testing" - "unsafe" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -func TestBufferSize(t *testing.T) { - bufferSize := unsafe.Sizeof(buffer{}) - if bufferSize < usermem.PageSize { - t.Errorf("buffer is less than a page") - } - if bufferSize > (2 * usermem.PageSize) { - t.Errorf("buffer is greater than two pages") - } -} diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go deleted file mode 100644 index 7af2afd8c..000000000 --- a/pkg/sentry/kernel/pipe/node_test.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -type sleeper struct { - context.Context - ch chan struct{} -} - -func newSleeperContext(t *testing.T) context.Context { - return &sleeper{ - Context: contexttest.Context(t), - ch: make(chan struct{}), - } -} - -func (s *sleeper) SleepStart() <-chan struct{} { - return s.ch -} - -func (s *sleeper) SleepFinish(bool) { -} - -func (s *sleeper) Cancel() { - s.ch <- struct{}{} -} - -func (s *sleeper) Interrupted() bool { - return len(s.ch) != 0 -} - -type openResult struct { - *fs.File - error -} - -var perms fs.FilePermissions = fs.FilePermissions{ - User: fs.PermMask{Read: true, Write: true}, -} - -func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, doneChan chan<- struct{}) (*fs.File, error) { - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - d := fs.NewDirent(inode, "pipe") - file, err := n.GetFile(ctx, d, flags) - if err != nil { - t.Fatalf("open with flags %+v failed: %v", flags, err) - } - if doneChan != nil { - doneChan <- struct{}{} - } - return file, err -} - -func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, resChan chan<- openResult) (*fs.File, error) { - inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) - d := fs.NewDirent(inode, "pipe") - file, err := n.GetFile(ctx, d, flags) - if resChan != nil { - resChan <- openResult{file, err} - } - return file, err -} - -func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), true, DefaultPipeSize, usermem.PageSize) -} - -func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), false, DefaultPipeSize, usermem.PageSize) -} - -// assertRecvBlocks ensures that a recv attempt on c blocks for at least -// blockDuration. This is useful for checking that a goroutine that is supposed -// to be executing a blocking operation is actually blocking. -func assertRecvBlocks(t *testing.T, c <-chan struct{}, blockDuration time.Duration, failMsg string) { - select { - case <-c: - t.Fatalf(failMsg) - case <-time.After(blockDuration): - // Ok, blocked for the required duration. - } -} - -func TestReadOpenBlocksForWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - // Verify that the open for read is blocking. - assertRecvBlocks(t, rDone, time.Millisecond*100, - "open for read not blocking with no writers") - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - <-wDone - <-rDone -} - -func TestWriteOpenBlocksForReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - // Verify that the open for write is blocking - assertRecvBlocks(t, wDone, time.Millisecond*100, - "open for write not blocking with no readers") - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - <-rDone - <-wDone -} - -func TestMultipleWriteOpenDoesntCountAsReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone1 := make(chan struct{}) - rDone2 := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone1) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone2) - - assertRecvBlocks(t, rDone1, time.Millisecond*100, - "open for read didn't block with no writers") - assertRecvBlocks(t, rDone2, time.Millisecond*100, - "open for read didn't block with no writers") - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - <-wDone - <-rDone2 - <-rDone1 -} - -func TestClosedReaderBlocksWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil) - rFile.DecRef() - - wDone := make(chan struct{}) - // This open for write should block because the reader is now gone. - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - assertRecvBlocks(t, wDone, time.Millisecond*100, - "open for write didn't block with no concurrent readers") - - // Open for read again. This should unblock the open for write. - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - <-rDone - <-wDone -} - -func TestReadWriteOpenNeverBlocks(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rwDone := make(chan struct{}) - // Open for read-write never wait for a reader or writer, even if the - // nonblocking flag is not set. - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true, NonBlocking: false}, rwDone) - <-rwDone -} - -func TestReadWriteOpenUnblocksReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - rwDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone) - - <-rwDone - <-rDone -} - -func TestReadWriteOpenUnblocksWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - rwDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, Write: true}, rwDone) - - <-rwDone - <-wDone -} - -func TestBlockedOpenIsCancellable(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - done := make(chan openResult) - go testOpen(ctx, t, f, fs.FileFlags{Read: true}, done) - select { - case <-done: - t.Fatalf("open for read didn't block with no writers") - case <-time.After(time.Millisecond * 100): - // Ok. - } - - ctx.(*sleeper).Cancel() - // If the cancel on the sleeper didn't work, the open for read would never - // return. - res := <-done - if res.error != syserror.ErrInterrupted { - t.Fatalf("Cancellation didn't cause GetFile to return fs.ErrInterrupted, got %v.", - res.error) - } -} - -func TestNonblockingReadOpenFileNoWriters(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for read failed with error %v.", err) - } -} - -func TestNonblockingWriteOpenFileNoReaders(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != syserror.ENXIO { - t.Fatalf("Nonblocking open for write failed unexpected error %v.", err) - } -} - -func TestNonBlockingReadOpenWithWriter(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - wDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Write: true}, wDone) - - // Open for write blocks since there are no readers yet. - assertRecvBlocks(t, wDone, time.Millisecond*100, - "Open for write didn't block with no reader.") - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for read failed with error %v.", err) - } - - // Open for write should now be unblocked. - <-wDone -} - -func TestNonBlockingWriteOpenWithReader(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - - rDone := make(chan struct{}) - go testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true}, rDone) - - // Open for write blocked, since no reader yet. - assertRecvBlocks(t, rDone, time.Millisecond*100, - "Open for reader didn't block with no writer.") - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != nil { - t.Fatalf("Nonblocking open for write failed with error %v.", err) - } - - // Open for write should now be unblocked. - <-rDone -} - -func TestAnonReadOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newAnonPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Read: true}, nil); err != nil { - t.Fatalf("open anon pipe for read failed: %v", err) - } -} - -func TestAnonWriteOpen(t *testing.T) { - ctx := newSleeperContext(t) - f := NewInodeOperations(ctx, perms, newAnonPipe(t)) - - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true}, nil); err != nil { - t.Fatalf("open anon pipe for write failed: %v", err) - } -} diff --git a/pkg/sentry/kernel/pipe/pipe_state_autogen.go b/pkg/sentry/kernel/pipe/pipe_state_autogen.go new file mode 100755 index 000000000..cad035789 --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_state_autogen.go @@ -0,0 +1,132 @@ +// automatically generated by stateify. + +package pipe + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *buffer) beforeSave() {} +func (x *buffer) save(m state.Map) { + x.beforeSave() + m.Save("data", &x.data) + m.Save("read", &x.read) + m.Save("write", &x.write) + m.Save("bufferEntry", &x.bufferEntry) +} + +func (x *buffer) afterLoad() {} +func (x *buffer) load(m state.Map) { + m.Load("data", &x.data) + m.Load("read", &x.read) + m.Load("write", &x.write) + m.Load("bufferEntry", &x.bufferEntry) +} + +func (x *bufferList) beforeSave() {} +func (x *bufferList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *bufferList) afterLoad() {} +func (x *bufferList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *bufferEntry) beforeSave() {} +func (x *bufferEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *bufferEntry) afterLoad() {} +func (x *bufferEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *inodeOperations) beforeSave() {} +func (x *inodeOperations) save(m state.Map) { + x.beforeSave() + m.Save("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Save("p", &x.p) +} + +func (x *inodeOperations) afterLoad() {} +func (x *inodeOperations) load(m state.Map) { + m.Load("InodeSimpleAttributes", &x.InodeSimpleAttributes) + m.Load("p", &x.p) +} + +func (x *Pipe) beforeSave() {} +func (x *Pipe) save(m state.Map) { + x.beforeSave() + m.Save("isNamed", &x.isNamed) + m.Save("atomicIOBytes", &x.atomicIOBytes) + m.Save("readers", &x.readers) + m.Save("writers", &x.writers) + m.Save("data", &x.data) + m.Save("max", &x.max) + m.Save("size", &x.size) + m.Save("hadWriter", &x.hadWriter) +} + +func (x *Pipe) afterLoad() {} +func (x *Pipe) load(m state.Map) { + m.Load("isNamed", &x.isNamed) + m.Load("atomicIOBytes", &x.atomicIOBytes) + m.Load("readers", &x.readers) + m.Load("writers", &x.writers) + m.Load("data", &x.data) + m.Load("max", &x.max) + m.Load("size", &x.size) + m.Load("hadWriter", &x.hadWriter) +} + +func (x *Reader) beforeSave() {} +func (x *Reader) save(m state.Map) { + x.beforeSave() + m.Save("ReaderWriter", &x.ReaderWriter) +} + +func (x *Reader) afterLoad() {} +func (x *Reader) load(m state.Map) { + m.Load("ReaderWriter", &x.ReaderWriter) +} + +func (x *ReaderWriter) beforeSave() {} +func (x *ReaderWriter) save(m state.Map) { + x.beforeSave() + m.Save("Pipe", &x.Pipe) +} + +func (x *ReaderWriter) afterLoad() {} +func (x *ReaderWriter) load(m state.Map) { + m.Load("Pipe", &x.Pipe) +} + +func (x *Writer) beforeSave() {} +func (x *Writer) save(m state.Map) { + x.beforeSave() + m.Save("ReaderWriter", &x.ReaderWriter) +} + +func (x *Writer) afterLoad() {} +func (x *Writer) load(m state.Map) { + m.Load("ReaderWriter", &x.ReaderWriter) +} + +func init() { + state.Register("pipe.buffer", (*buffer)(nil), state.Fns{Save: (*buffer).save, Load: (*buffer).load}) + state.Register("pipe.bufferList", (*bufferList)(nil), state.Fns{Save: (*bufferList).save, Load: (*bufferList).load}) + state.Register("pipe.bufferEntry", (*bufferEntry)(nil), state.Fns{Save: (*bufferEntry).save, Load: (*bufferEntry).load}) + state.Register("pipe.inodeOperations", (*inodeOperations)(nil), state.Fns{Save: (*inodeOperations).save, Load: (*inodeOperations).load}) + state.Register("pipe.Pipe", (*Pipe)(nil), state.Fns{Save: (*Pipe).save, Load: (*Pipe).load}) + state.Register("pipe.Reader", (*Reader)(nil), state.Fns{Save: (*Reader).save, Load: (*Reader).load}) + state.Register("pipe.ReaderWriter", (*ReaderWriter)(nil), state.Fns{Save: (*ReaderWriter).save, Load: (*ReaderWriter).load}) + state.Register("pipe.Writer", (*Writer)(nil), state.Fns{Save: (*Writer).save, Load: (*Writer).load}) +} diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go deleted file mode 100644 index e3a14b665..000000000 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "bytes" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestPipeRW(t *testing.T) { - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() - - msg := []byte("here's some bytes") - wantN := int64(len(msg)) - n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) - if n != wantN || err != nil { - t.Fatalf("Writev: got (%d, %v), wanted (%d, nil)", n, err, wantN) - } - - buf := make([]byte, len(msg)) - n, err = r.Readv(ctx, usermem.BytesIOSequence(buf)) - if n != wantN || err != nil || !bytes.Equal(buf, msg) { - t.Fatalf("Readv: got (%d, %v) %q, wanted (%d, nil) %q", n, err, buf, wantN, msg) - } -} - -func TestPipeReadBlock(t *testing.T) { - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) - defer r.DecRef() - defer w.DecRef() - - n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1))) - if n != 0 || err != syserror.ErrWouldBlock { - t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, syserror.ErrWouldBlock) - } -} - -func TestPipeWriteBlock(t *testing.T) { - const atomicIOBytes = 2 - const capacity = MinimumPipeSize - - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() - - msg := make([]byte, capacity+1) - n, err := w.Writev(ctx, usermem.BytesIOSequence(msg)) - if wantN, wantErr := int64(capacity), syserror.ErrWouldBlock; n != wantN || err != wantErr { - t.Fatalf("Writev: got (%d, %v), wanted (%d, %v)", n, err, wantN, wantErr) - } -} - -func TestPipeWriteUntilEnd(t *testing.T) { - const atomicIOBytes = 2 - - ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) - defer r.DecRef() - defer w.DecRef() - - msg := []byte("here's some bytes") - - wDone := make(chan struct{}, 0) - rDone := make(chan struct{}, 0) - defer func() { - // Signal the reader to stop and wait until it does so. - close(wDone) - <-rDone - }() - - go func() { - defer close(rDone) - // Read from r until done is closed. - ctx := contexttest.Context(t) - buf := make([]byte, len(msg)+1) - dst := usermem.BytesIOSequence(buf) - e, ch := waiter.NewChannelEntry(nil) - r.EventRegister(&e, waiter.EventIn) - defer r.EventUnregister(&e) - for { - n, err := r.Readv(ctx, dst) - dst = dst.DropFirst64(n) - if err == syserror.ErrWouldBlock { - select { - case <-ch: - continue - case <-wDone: - // We expect to have 1 byte left in dst since len(buf) == - // len(msg)+1. - if dst.NumBytes() != 1 || !bytes.Equal(buf[:len(msg)], msg) { - t.Errorf("Reader: got %q (%d bytes remaining), wanted %q", buf, dst.NumBytes(), msg) - } - return - } - } - if err != nil { - t.Fatalf("Readv: got unexpected error %v", err) - } - } - }() - - src := usermem.BytesIOSequence(msg) - e, ch := waiter.NewChannelEntry(nil) - w.EventRegister(&e, waiter.EventOut) - defer w.EventUnregister(&e) - for src.NumBytes() != 0 { - n, err := w.Writev(ctx, src) - src = src.DropFirst64(n) - if err == syserror.ErrWouldBlock { - <-ch - continue - } - if err != nil { - t.Fatalf("Writev: got (%d, %v)", n, err) - } - } -} diff --git a/pkg/sentry/kernel/process_group_list.go b/pkg/sentry/kernel/process_group_list.go new file mode 100755 index 000000000..853145237 --- /dev/null +++ b/pkg/sentry/kernel/process_group_list.go @@ -0,0 +1,173 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type processGroupElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (processGroupElementMapper) linkerFor(elem *ProcessGroup) *ProcessGroup { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type processGroupList struct { + head *ProcessGroup + tail *ProcessGroup +} + +// Reset resets list l to the empty state. +func (l *processGroupList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *processGroupList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *processGroupList) Front() *ProcessGroup { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *processGroupList) Back() *ProcessGroup { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *processGroupList) PushFront(e *ProcessGroup) { + processGroupElementMapper{}.linkerFor(e).SetNext(l.head) + processGroupElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + processGroupElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *processGroupList) PushBack(e *ProcessGroup) { + processGroupElementMapper{}.linkerFor(e).SetNext(nil) + processGroupElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + processGroupElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *processGroupList) PushBackList(m *processGroupList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + processGroupElementMapper{}.linkerFor(l.tail).SetNext(m.head) + processGroupElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *processGroupList) InsertAfter(b, e *ProcessGroup) { + a := processGroupElementMapper{}.linkerFor(b).Next() + processGroupElementMapper{}.linkerFor(e).SetNext(a) + processGroupElementMapper{}.linkerFor(e).SetPrev(b) + processGroupElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + processGroupElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *processGroupList) InsertBefore(a, e *ProcessGroup) { + b := processGroupElementMapper{}.linkerFor(a).Prev() + processGroupElementMapper{}.linkerFor(e).SetNext(a) + processGroupElementMapper{}.linkerFor(e).SetPrev(b) + processGroupElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + processGroupElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *processGroupList) Remove(e *ProcessGroup) { + prev := processGroupElementMapper{}.linkerFor(e).Prev() + next := processGroupElementMapper{}.linkerFor(e).Next() + + if prev != nil { + processGroupElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + processGroupElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type processGroupEntry struct { + next *ProcessGroup + prev *ProcessGroup +} + +// Next returns the entry that follows e in the list. +func (e *processGroupEntry) Next() *ProcessGroup { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *processGroupEntry) Prev() *ProcessGroup { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *processGroupEntry) SetNext(elem *ProcessGroup) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *processGroupEntry) SetPrev(elem *ProcessGroup) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD deleted file mode 100644 index 1725b8562..000000000 --- a/pkg/sentry/kernel/sched/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sched", - srcs = [ - "cpuset.go", - "sched.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/sched", - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "sched_test", - size = "small", - srcs = ["cpuset_test.go"], - embed = [":sched"], -) diff --git a/pkg/sentry/kernel/sched/cpuset_test.go b/pkg/sentry/kernel/sched/cpuset_test.go deleted file mode 100644 index 3af9f1197..000000000 --- a/pkg/sentry/kernel/sched/cpuset_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sched - -import ( - "testing" -) - -func TestNumCPUs(t *testing.T) { - for i := uint(0); i < 1024; i++ { - c := NewCPUSet(i) - for j := uint(0); j < i; j++ { - c.Set(j) - } - n := c.NumCPUs() - if n != i { - t.Errorf("got wrong number of cpus %d, want %d", n, i) - } - } -} - -func TestClearAbove(t *testing.T) { - const n = 1024 - c := NewFullCPUSet(n) - for i := uint(0); i < n; i++ { - cpu := n - i - c.ClearAbove(cpu) - if got := c.NumCPUs(); got != cpu { - t.Errorf("iteration %d: got %d cpus, wanted %d", i, got, cpu) - } - } -} diff --git a/pkg/sentry/kernel/sched/sched_state_autogen.go b/pkg/sentry/kernel/sched/sched_state_autogen.go new file mode 100755 index 000000000..2a482732e --- /dev/null +++ b/pkg/sentry/kernel/sched/sched_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package sched + diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD deleted file mode 100644 index 36edf10f3..000000000 --- a/pkg/sentry/kernel/semaphore/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "semaphore", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*waiter", - "Linker": "*waiter", - }, -) - -go_library( - name = "semaphore", - srcs = [ - "semaphore.go", - "waiter_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/semaphore", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/syserror", - ], -) - -go_test( - name = "semaphore_test", - size = "small", - srcs = ["semaphore_test.go"], - embed = [":semaphore"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/kernel/auth", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go new file mode 100755 index 000000000..f225f6f75 --- /dev/null +++ b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go @@ -0,0 +1,115 @@ +// automatically generated by stateify. + +package semaphore + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Registry) beforeSave() {} +func (x *Registry) save(m state.Map) { + x.beforeSave() + m.Save("userNS", &x.userNS) + m.Save("semaphores", &x.semaphores) + m.Save("lastIDUsed", &x.lastIDUsed) +} + +func (x *Registry) afterLoad() {} +func (x *Registry) load(m state.Map) { + m.Load("userNS", &x.userNS) + m.Load("semaphores", &x.semaphores) + m.Load("lastIDUsed", &x.lastIDUsed) +} + +func (x *Set) beforeSave() {} +func (x *Set) save(m state.Map) { + x.beforeSave() + m.Save("registry", &x.registry) + m.Save("ID", &x.ID) + m.Save("key", &x.key) + m.Save("creator", &x.creator) + m.Save("owner", &x.owner) + m.Save("perms", &x.perms) + m.Save("opTime", &x.opTime) + m.Save("changeTime", &x.changeTime) + m.Save("sems", &x.sems) + m.Save("dead", &x.dead) +} + +func (x *Set) afterLoad() {} +func (x *Set) load(m state.Map) { + m.Load("registry", &x.registry) + m.Load("ID", &x.ID) + m.Load("key", &x.key) + m.Load("creator", &x.creator) + m.Load("owner", &x.owner) + m.Load("perms", &x.perms) + m.Load("opTime", &x.opTime) + m.Load("changeTime", &x.changeTime) + m.Load("sems", &x.sems) + m.Load("dead", &x.dead) +} + +func (x *sem) beforeSave() {} +func (x *sem) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.waiters) { m.Failf("waiters is %v, expected zero", x.waiters) } + m.Save("value", &x.value) + m.Save("pid", &x.pid) +} + +func (x *sem) afterLoad() {} +func (x *sem) load(m state.Map) { + m.Load("value", &x.value) + m.Load("pid", &x.pid) +} + +func (x *waiter) beforeSave() {} +func (x *waiter) save(m state.Map) { + x.beforeSave() + m.Save("waiterEntry", &x.waiterEntry) + m.Save("value", &x.value) + m.Save("ch", &x.ch) +} + +func (x *waiter) afterLoad() {} +func (x *waiter) load(m state.Map) { + m.Load("waiterEntry", &x.waiterEntry) + m.Load("value", &x.value) + m.Load("ch", &x.ch) +} + +func (x *waiterList) beforeSave() {} +func (x *waiterList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *waiterList) afterLoad() {} +func (x *waiterList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *waiterEntry) beforeSave() {} +func (x *waiterEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *waiterEntry) afterLoad() {} +func (x *waiterEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("semaphore.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load}) + state.Register("semaphore.Set", (*Set)(nil), state.Fns{Save: (*Set).save, Load: (*Set).load}) + state.Register("semaphore.sem", (*sem)(nil), state.Fns{Save: (*sem).save, Load: (*sem).load}) + state.Register("semaphore.waiter", (*waiter)(nil), state.Fns{Save: (*waiter).save, Load: (*waiter).load}) + state.Register("semaphore.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load}) + state.Register("semaphore.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load}) +} diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go deleted file mode 100644 index c235f6ca4..000000000 --- a/pkg/sentry/kernel/semaphore/semaphore_test.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package semaphore - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" -) - -func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} { - ch, _, err := set.executeOps(ctx, ops, 123) - if err != nil { - t.Fatalf("ExecuteOps(ops) failed, err: %v, ops: %+v", err, ops) - } - if block { - if ch == nil { - t.Fatalf("ExecuteOps(ops) got: nil, expected: !nil, ops: %+v", ops) - } - if signalled(ch) { - t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) - } - } else { - if ch != nil { - t.Fatalf("ExecuteOps(ops) got: %v, expected: nil, ops: %+v", ch, ops) - } - } - return ch -} - -func signalled(ch chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - -func TestBasic(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{ID: 123, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 1}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -1 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -1 - ch1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - if !signalled(ch1) { - t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) - } -} - -func TestWaitForZero(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{ID: 123, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 0}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -2 - ch1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 0 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = 0 - chZero1 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 0 - chZero2 := executeOps(ctx, t, set, ops, true) - - ops[0].SemOp = 1 - executeOps(ctx, t, set, ops, false) - if !signalled(ch1) { - t.Fatalf("ExecuteOps(ops) channel should have been signalled, ops: %+v, set: %+v", ops, set) - } - - ops[0].SemOp = -2 - executeOps(ctx, t, set, ops, false) - if !signalled(chZero1) { - t.Fatalf("ExecuteOps(ops) channel zero 1 should have been signalled, ops: %+v, set: %+v", ops, set) - } - if !signalled(chZero2) { - t.Fatalf("ExecuteOps(ops) channel zero 2 should have been signalled, ops: %+v, set: %+v", ops, set) - } -} - -func TestNoWait(t *testing.T) { - ctx := contexttest.Context(t) - set := &Set{ID: 123, sems: make([]sem, 1)} - ops := []linux.Sembuf{ - {SemOp: 1}, - } - executeOps(ctx, t, set, ops, false) - - ops[0].SemOp = -2 - ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) - } - - ops[0].SemOp = 0 - ops[0].SemFlg = linux.IPC_NOWAIT - if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { - t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) - } -} - -func TestUnregister(t *testing.T) { - ctx := contexttest.Context(t) - r := NewRegistry(auth.NewRootUserNamespace()) - set, err := r.FindOrCreate(ctx, 123, 2, linux.FileMode(0x600), true, true, true) - if err != nil { - t.Fatalf("FindOrCreate() failed, err: %v", err) - } - if got := r.FindByID(set.ID); got.ID != set.ID { - t.Fatalf("FindById(%d) failed, got: %+v, expected: %+v", set.ID, got, set) - } - - ops := []linux.Sembuf{ - {SemOp: -1}, - } - chs := make([]chan struct{}, 0, 5) - for i := 0; i < 5; i++ { - ch := executeOps(ctx, t, set, ops, true) - chs = append(chs, ch) - } - - creds := auth.CredentialsFromContext(ctx) - if err := r.RemoveID(set.ID, creds); err != nil { - t.Fatalf("RemoveID(%d) failed, err: %v", set.ID, err) - } - if !set.dead { - t.Fatalf("set is not dead: %+v", set) - } - if got := r.FindByID(set.ID); got != nil { - t.Fatalf("FindById(%d) failed, got: %+v, expected: nil", set.ID, got) - } - for i, ch := range chs { - if !signalled(ch) { - t.Fatalf("channel %d should have been signalled", i) - } - } -} diff --git a/pkg/sentry/kernel/semaphore/waiter_list.go b/pkg/sentry/kernel/semaphore/waiter_list.go new file mode 100755 index 000000000..33e29fb55 --- /dev/null +++ b/pkg/sentry/kernel/semaphore/waiter_list.go @@ -0,0 +1,173 @@ +package semaphore + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *waiter) *waiter { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *waiter + tail *waiter +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *waiterList) Front() *waiter { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *waiterList) Back() *waiter { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *waiterList) PushFront(e *waiter) { + waiterElementMapper{}.linkerFor(e).SetNext(l.head) + waiterElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *waiterList) PushBack(e *waiter) { + waiterElementMapper{}.linkerFor(e).SetNext(nil) + waiterElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *waiterList) InsertAfter(b, e *waiter) { + a := waiterElementMapper{}.linkerFor(b).Next() + waiterElementMapper{}.linkerFor(e).SetNext(a) + waiterElementMapper{}.linkerFor(e).SetPrev(b) + waiterElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *waiterList) InsertBefore(a, e *waiter) { + b := waiterElementMapper{}.linkerFor(a).Prev() + waiterElementMapper{}.linkerFor(e).SetNext(a) + waiterElementMapper{}.linkerFor(e).SetPrev(b) + waiterElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *waiterList) Remove(e *waiter) { + prev := waiterElementMapper{}.linkerFor(e).Prev() + next := waiterElementMapper{}.linkerFor(e).Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *waiter + prev *waiter +} + +// Next returns the entry that follows e in the list. +func (e *waiterEntry) Next() *waiter { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *waiterEntry) Prev() *waiter { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *waiterEntry) SetNext(elem *waiter) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *waiterEntry) SetPrev(elem *waiter) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo.go b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo.go new file mode 100755 index 000000000..25ad17a4e --- /dev/null +++ b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo.go @@ -0,0 +1,55 @@ +package kernel + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/third_party/gvsync" +) + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoadTaskGoroutineSchedInfo(sc *gvsync.SeqCount, ptr *TaskGoroutineSchedInfo) TaskGoroutineSchedInfo { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val TaskGoroutineSchedInfo + for { + epoch := sc.BeginRead() + if gvsync.RaceEnabled { + + gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoadTaskGoroutineSchedInfo(sc *gvsync.SeqCount, epoch gvsync.SeqCountEpoch, ptr *TaskGoroutineSchedInfo) (TaskGoroutineSchedInfo, bool) { + var val TaskGoroutineSchedInfo + if gvsync.RaceEnabled { + gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func initTaskGoroutineSchedInfo() { + var val TaskGoroutineSchedInfo + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := gvsync.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/sentry/kernel/session_list.go b/pkg/sentry/kernel/session_list.go new file mode 100755 index 000000000..9ba27b164 --- /dev/null +++ b/pkg/sentry/kernel/session_list.go @@ -0,0 +1,173 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type sessionElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (sessionElementMapper) linkerFor(elem *Session) *Session { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type sessionList struct { + head *Session + tail *Session +} + +// Reset resets list l to the empty state. +func (l *sessionList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *sessionList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *sessionList) Front() *Session { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *sessionList) Back() *Session { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *sessionList) PushFront(e *Session) { + sessionElementMapper{}.linkerFor(e).SetNext(l.head) + sessionElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + sessionElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *sessionList) PushBack(e *Session) { + sessionElementMapper{}.linkerFor(e).SetNext(nil) + sessionElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + sessionElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *sessionList) PushBackList(m *sessionList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + sessionElementMapper{}.linkerFor(l.tail).SetNext(m.head) + sessionElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *sessionList) InsertAfter(b, e *Session) { + a := sessionElementMapper{}.linkerFor(b).Next() + sessionElementMapper{}.linkerFor(e).SetNext(a) + sessionElementMapper{}.linkerFor(e).SetPrev(b) + sessionElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + sessionElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *sessionList) InsertBefore(a, e *Session) { + b := sessionElementMapper{}.linkerFor(a).Prev() + sessionElementMapper{}.linkerFor(e).SetNext(a) + sessionElementMapper{}.linkerFor(e).SetPrev(b) + sessionElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + sessionElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *sessionList) Remove(e *Session) { + prev := sessionElementMapper{}.linkerFor(e).Prev() + next := sessionElementMapper{}.linkerFor(e).Next() + + if prev != nil { + sessionElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + sessionElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type sessionEntry struct { + next *Session + prev *Session +} + +// Next returns the entry that follows e in the list. +func (e *sessionEntry) Next() *Session { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *sessionEntry) Prev() *Session { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *sessionEntry) SetNext(elem *Session) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *sessionEntry) SetPrev(elem *Session) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD deleted file mode 100644 index aa7471eb6..000000000 --- a/pkg/sentry/kernel/shm/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "shm", - srcs = [ - "device.go", - "shm.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/shm", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/kernel/shm/shm_state_autogen.go b/pkg/sentry/kernel/shm/shm_state_autogen.go new file mode 100755 index 000000000..4fd44bb6a --- /dev/null +++ b/pkg/sentry/kernel/shm/shm_state_autogen.go @@ -0,0 +1,74 @@ +// automatically generated by stateify. + +package shm + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Registry) beforeSave() {} +func (x *Registry) save(m state.Map) { + x.beforeSave() + m.Save("userNS", &x.userNS) + m.Save("shms", &x.shms) + m.Save("keysToShms", &x.keysToShms) + m.Save("totalPages", &x.totalPages) + m.Save("lastIDUsed", &x.lastIDUsed) +} + +func (x *Registry) afterLoad() {} +func (x *Registry) load(m state.Map) { + m.Load("userNS", &x.userNS) + m.Load("shms", &x.shms) + m.Load("keysToShms", &x.keysToShms) + m.Load("totalPages", &x.totalPages) + m.Load("lastIDUsed", &x.lastIDUsed) +} + +func (x *Shm) beforeSave() {} +func (x *Shm) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("mfp", &x.mfp) + m.Save("registry", &x.registry) + m.Save("ID", &x.ID) + m.Save("creator", &x.creator) + m.Save("size", &x.size) + m.Save("effectiveSize", &x.effectiveSize) + m.Save("fr", &x.fr) + m.Save("key", &x.key) + m.Save("perms", &x.perms) + m.Save("owner", &x.owner) + m.Save("attachTime", &x.attachTime) + m.Save("detachTime", &x.detachTime) + m.Save("changeTime", &x.changeTime) + m.Save("creatorPID", &x.creatorPID) + m.Save("lastAttachDetachPID", &x.lastAttachDetachPID) + m.Save("pendingDestruction", &x.pendingDestruction) +} + +func (x *Shm) afterLoad() {} +func (x *Shm) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("mfp", &x.mfp) + m.Load("registry", &x.registry) + m.Load("ID", &x.ID) + m.Load("creator", &x.creator) + m.Load("size", &x.size) + m.Load("effectiveSize", &x.effectiveSize) + m.Load("fr", &x.fr) + m.Load("key", &x.key) + m.Load("perms", &x.perms) + m.Load("owner", &x.owner) + m.Load("attachTime", &x.attachTime) + m.Load("detachTime", &x.detachTime) + m.Load("changeTime", &x.changeTime) + m.Load("creatorPID", &x.creatorPID) + m.Load("lastAttachDetachPID", &x.lastAttachDetachPID) + m.Load("pendingDestruction", &x.pendingDestruction) +} + +func init() { + state.Register("shm.Registry", (*Registry)(nil), state.Fns{Save: (*Registry).save, Load: (*Registry).load}) + state.Register("shm.Shm", (*Shm)(nil), state.Fns{Save: (*Shm).save, Load: (*Shm).load}) +} diff --git a/pkg/sentry/kernel/socket_list.go b/pkg/sentry/kernel/socket_list.go new file mode 100755 index 000000000..aed0a555e --- /dev/null +++ b/pkg/sentry/kernel/socket_list.go @@ -0,0 +1,173 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type socketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (socketElementMapper) linkerFor(elem *SocketEntry) *SocketEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type socketList struct { + head *SocketEntry + tail *SocketEntry +} + +// Reset resets list l to the empty state. +func (l *socketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *socketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *socketList) Front() *SocketEntry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *socketList) Back() *SocketEntry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *socketList) PushFront(e *SocketEntry) { + socketElementMapper{}.linkerFor(e).SetNext(l.head) + socketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + socketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *socketList) PushBack(e *SocketEntry) { + socketElementMapper{}.linkerFor(e).SetNext(nil) + socketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + socketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *socketList) PushBackList(m *socketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + socketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + socketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *socketList) InsertAfter(b, e *SocketEntry) { + a := socketElementMapper{}.linkerFor(b).Next() + socketElementMapper{}.linkerFor(e).SetNext(a) + socketElementMapper{}.linkerFor(e).SetPrev(b) + socketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + socketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *socketList) InsertBefore(a, e *SocketEntry) { + b := socketElementMapper{}.linkerFor(a).Prev() + socketElementMapper{}.linkerFor(e).SetNext(a) + socketElementMapper{}.linkerFor(e).SetPrev(b) + socketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + socketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *socketList) Remove(e *SocketEntry) { + prev := socketElementMapper{}.linkerFor(e).Prev() + next := socketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + socketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + socketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type socketEntry struct { + next *SocketEntry + prev *SocketEntry +} + +// Next returns the entry that follows e in the list. +func (e *socketEntry) Next() *SocketEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *socketEntry) Prev() *SocketEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *socketEntry) SetNext(elem *SocketEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *socketEntry) SetPrev(elem *SocketEntry) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/table_test.go b/pkg/sentry/kernel/table_test.go deleted file mode 100644 index 32cf47e05..000000000 --- a/pkg/sentry/kernel/table_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/sentry/arch" -) - -const ( - maxTestSyscall = 1000 -) - -func createSyscallTable() *SyscallTable { - m := make(map[uintptr]Syscall) - for i := uintptr(0); i <= maxTestSyscall; i++ { - j := i - m[i] = Syscall{ - Fn: func(*Task, arch.SyscallArguments) (uintptr, *SyscallControl, error) { - return j, nil, nil - }, - } - } - - s := &SyscallTable{ - OS: abi.Linux, - Arch: arch.AMD64, - Table: m, - } - - RegisterSyscallTable(s) - return s -} - -func TestTable(t *testing.T) { - table := createSyscallTable() - defer func() { - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} - }() - - // Go through all functions and check that they return the right value. - for i := uintptr(0); i < maxTestSyscall; i++ { - fn := table.Lookup(i) - if fn == nil { - t.Errorf("Syscall %v is set to nil", i) - continue - } - - v, _, _ := fn(nil, arch.SyscallArguments{}) - if v != i { - t.Errorf("Wrong return value for syscall %v: expected %v, got %v", i, i, v) - } - } - - // Check that values outside the range return nil. - for i := uintptr(maxTestSyscall + 1); i < maxTestSyscall+100; i++ { - fn := table.Lookup(i) - if fn != nil { - t.Errorf("Syscall %v is not nil: %v", i, fn) - continue - } - } -} - -func BenchmarkTableLookup(b *testing.B) { - table := createSyscallTable() - - b.ResetTimer() - - j := uintptr(0) - for i := 0; i < b.N; i++ { - table.Lookup(j) - j = (j + 1) % 310 - } - - b.StopTimer() - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} -} - -func BenchmarkTableMapLookup(b *testing.B) { - table := createSyscallTable() - - b.ResetTimer() - - j := uintptr(0) - for i := 0; i < b.N; i++ { - table.mapLookup(j) - j = (j + 1) % 310 - } - - b.StopTimer() - // Cleanup registered tables to keep tests separate. - allSyscallTables = []*SyscallTable{} -} diff --git a/pkg/sentry/kernel/task_list.go b/pkg/sentry/kernel/task_list.go new file mode 100755 index 000000000..57d3f098d --- /dev/null +++ b/pkg/sentry/kernel/task_list.go @@ -0,0 +1,173 @@ +package kernel + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type taskElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (taskElementMapper) linkerFor(elem *Task) *Task { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type taskList struct { + head *Task + tail *Task +} + +// Reset resets list l to the empty state. +func (l *taskList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *taskList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *taskList) Front() *Task { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *taskList) Back() *Task { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *taskList) PushFront(e *Task) { + taskElementMapper{}.linkerFor(e).SetNext(l.head) + taskElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + taskElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *taskList) PushBack(e *Task) { + taskElementMapper{}.linkerFor(e).SetNext(nil) + taskElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + taskElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *taskList) PushBackList(m *taskList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + taskElementMapper{}.linkerFor(l.tail).SetNext(m.head) + taskElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *taskList) InsertAfter(b, e *Task) { + a := taskElementMapper{}.linkerFor(b).Next() + taskElementMapper{}.linkerFor(e).SetNext(a) + taskElementMapper{}.linkerFor(e).SetPrev(b) + taskElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + taskElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *taskList) InsertBefore(a, e *Task) { + b := taskElementMapper{}.linkerFor(a).Prev() + taskElementMapper{}.linkerFor(e).SetNext(a) + taskElementMapper{}.linkerFor(e).SetPrev(b) + taskElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + taskElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *taskList) Remove(e *Task) { + prev := taskElementMapper{}.linkerFor(e).Prev() + next := taskElementMapper{}.linkerFor(e).Next() + + if prev != nil { + taskElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + taskElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type taskEntry struct { + next *Task + prev *Task +} + +// Next returns the entry that follows e in the list. +func (e *taskEntry) Next() *Task { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *taskEntry) Prev() *Task { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *taskEntry) SetNext(elem *Task) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *taskEntry) SetPrev(elem *Task) { + e.prev = elem +} diff --git a/pkg/sentry/kernel/task_test.go b/pkg/sentry/kernel/task_test.go deleted file mode 100644 index cfcde9a7a..000000000 --- a/pkg/sentry/kernel/task_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/kernel/sched" -) - -func TestTaskCPU(t *testing.T) { - for _, test := range []struct { - mask sched.CPUSet - tid ThreadID - cpu int32 - }{ - { - mask: []byte{0xff}, - tid: 1, - cpu: 0, - }, - { - mask: []byte{0xff}, - tid: 10, - cpu: 1, - }, - { - // more than 8 cpus. - mask: []byte{0xff, 0xff}, - tid: 10, - cpu: 9, - }, - { - // missing the first cpu. - mask: []byte{0xfe}, - tid: 1, - cpu: 1, - }, - { - mask: []byte{0xfe}, - tid: 10, - cpu: 3, - }, - { - // missing the fifth cpu. - mask: []byte{0xef}, - tid: 10, - cpu: 2, - }, - } { - assigned := assignCPU(test.mask, test.tid) - if test.cpu != assigned { - t.Errorf("assignCPU(%v, %v) got %v, want %v", test.mask, test.tid, assigned, test.cpu) - } - } - -} diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD deleted file mode 100644 index 9beae4b31..000000000 --- a/pkg/sentry/kernel/time/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "time", - srcs = [ - "context.go", - "time.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/time", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/kernel/time/time_state_autogen.go b/pkg/sentry/kernel/time/time_state_autogen.go new file mode 100755 index 000000000..77b676cec --- /dev/null +++ b/pkg/sentry/kernel/time/time_state_autogen.go @@ -0,0 +1,56 @@ +// automatically generated by stateify. + +package time + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Time) beforeSave() {} +func (x *Time) save(m state.Map) { + x.beforeSave() + m.Save("ns", &x.ns) +} + +func (x *Time) afterLoad() {} +func (x *Time) load(m state.Map) { + m.Load("ns", &x.ns) +} + +func (x *Setting) beforeSave() {} +func (x *Setting) save(m state.Map) { + x.beforeSave() + m.Save("Enabled", &x.Enabled) + m.Save("Next", &x.Next) + m.Save("Period", &x.Period) +} + +func (x *Setting) afterLoad() {} +func (x *Setting) load(m state.Map) { + m.Load("Enabled", &x.Enabled) + m.Load("Next", &x.Next) + m.Load("Period", &x.Period) +} + +func (x *Timer) beforeSave() {} +func (x *Timer) save(m state.Map) { + x.beforeSave() + m.Save("clock", &x.clock) + m.Save("listener", &x.listener) + m.Save("setting", &x.setting) + m.Save("paused", &x.paused) +} + +func (x *Timer) afterLoad() {} +func (x *Timer) load(m state.Map) { + m.Load("clock", &x.clock) + m.Load("listener", &x.listener) + m.Load("setting", &x.setting) + m.Load("paused", &x.paused) +} + +func init() { + state.Register("time.Time", (*Time)(nil), state.Fns{Save: (*Time).save, Load: (*Time).load}) + state.Register("time.Setting", (*Setting)(nil), state.Fns{Save: (*Setting).save, Load: (*Setting).load}) + state.Register("time.Timer", (*Timer)(nil), state.Fns{Save: (*Timer).save, Load: (*Timer).load}) +} diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go deleted file mode 100644 index 849c5b646..000000000 --- a/pkg/sentry/kernel/timekeeper_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kernel - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - sentrytime "gvisor.dev/gvisor/pkg/sentry/time" - "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -// mockClocks is a sentrytime.Clocks that simply returns the times in the -// struct. -type mockClocks struct { - monotonic int64 - realtime int64 -} - -// Update implements sentrytime.Clocks.Update. It does nothing. -func (*mockClocks) Update() (monotonicParams sentrytime.Parameters, monotonicOk bool, realtimeParam sentrytime.Parameters, realtimeOk bool) { - return -} - -// Update implements sentrytime.Clocks.GetTime. -func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) { - switch id { - case sentrytime.Monotonic: - return c.monotonic, nil - case sentrytime.Realtime: - return c.realtime, nil - default: - return 0, syserror.EINVAL - } -} - -// stateTestClocklessTimekeeper returns a test Timekeeper which has not had -// SetClocks called. -func stateTestClocklessTimekeeper(tb testing.TB) *Timekeeper { - ctx := contexttest.Context(tb) - mfp := pgalloc.MemoryFileProviderFromContext(ctx) - fr, err := mfp.MemoryFile().Allocate(usermem.PageSize, usage.Anonymous) - if err != nil { - tb.Fatalf("failed to allocate memory: %v", err) - } - return &Timekeeper{ - params: NewVDSOParamPage(mfp, fr), - } -} - -func stateTestTimekeeper(tb testing.TB) *Timekeeper { - t := stateTestClocklessTimekeeper(tb) - t.SetClocks(sentrytime.NewCalibratedClocks()) - return t -} - -// TestTimekeeperMonotonicZero tests that monotonic time starts at zero. -func TestTimekeeperMonotonicZero(t *testing.T) { - c := &mockClocks{ - monotonic: 100000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.SetClocks(c) - defer tk.Destroy() - - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 0 { - t.Errorf("GetTime got %d want 0", now) - } - - c.monotonic += 10 - - now, err = tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 10 { - t.Errorf("GetTime got %d want 10", now) - } -} - -// TestTimekeeperMonotonicJumpForward tests that monotonic time jumps forward -// after restore. -func TestTimekeeperMonotonicForward(t *testing.T) { - c := &mockClocks{ - monotonic: 900000, - realtime: 600000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.restored = make(chan struct{}) - tk.saveMonotonic = 100000 - tk.saveRealtime = 400000 - tk.SetClocks(c) - defer tk.Destroy() - - // The monotonic clock should jump ahead by 200000 to 300000. - // - // The new system monotonic time (900000) is irrelevant to what the app - // sees. - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 300000 { - t.Errorf("GetTime got %d want 300000", now) - } -} - -// TestTimekeeperMonotonicJumpBackwards tests that monotonic time does not jump -// backwards when realtime goes backwards. -func TestTimekeeperMonotonicJumpBackwards(t *testing.T) { - c := &mockClocks{ - monotonic: 900000, - realtime: 400000, - } - - tk := stateTestClocklessTimekeeper(t) - tk.restored = make(chan struct{}) - tk.saveMonotonic = 100000 - tk.saveRealtime = 600000 - tk.SetClocks(c) - defer tk.Destroy() - - // The monotonic clock should remain at 100000. - // - // The new system monotonic time (900000) is irrelevant to what the app - // sees and we don't want to jump the monotonic clock backwards like - // realtime did. - now, err := tk.GetTime(sentrytime.Monotonic) - if err != nil { - t.Errorf("GetTime err got %v want nil", err) - } - if now != 100000 { - t.Errorf("GetTime got %d want 100000", now) - } -} diff --git a/pkg/sentry/kernel/uncaught_signal.proto b/pkg/sentry/kernel/uncaught_signal.proto deleted file mode 100644 index 0bdb062cb..000000000 --- a/pkg/sentry/kernel/uncaught_signal.proto +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -import "pkg/sentry/arch/registers.proto"; - -message UncaughtSignal { - // Thread ID. - int32 tid = 1; - - // Process ID. - int32 pid = 2; - - // Registers at the time of the fault or signal. - Registers registers = 3; - - // Signal number. - int32 signal_number = 4; - - // The memory location which caused the fault (set if applicable, 0 - // otherwise). This will be set for SIGILL, SIGFPE, SIGSEGV, and SIGBUS. - uint64 fault_addr = 5; -} diff --git a/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go new file mode 100755 index 000000000..822e549ab --- /dev/null +++ b/pkg/sentry/kernel/uncaught_signal_go_proto/uncaught_signal.pb.go @@ -0,0 +1,119 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/kernel/uncaught_signal.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + registers_go_proto "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type UncaughtSignal struct { + Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"` + Pid int32 `protobuf:"varint,2,opt,name=pid,proto3" json:"pid,omitempty"` + Registers *registers_go_proto.Registers `protobuf:"bytes,3,opt,name=registers,proto3" json:"registers,omitempty"` + SignalNumber int32 `protobuf:"varint,4,opt,name=signal_number,json=signalNumber,proto3" json:"signal_number,omitempty"` + FaultAddr uint64 `protobuf:"varint,5,opt,name=fault_addr,json=faultAddr,proto3" json:"fault_addr,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *UncaughtSignal) Reset() { *m = UncaughtSignal{} } +func (m *UncaughtSignal) String() string { return proto.CompactTextString(m) } +func (*UncaughtSignal) ProtoMessage() {} +func (*UncaughtSignal) Descriptor() ([]byte, []int) { + return fileDescriptor_5ca9e03e13704688, []int{0} +} + +func (m *UncaughtSignal) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_UncaughtSignal.Unmarshal(m, b) +} +func (m *UncaughtSignal) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_UncaughtSignal.Marshal(b, m, deterministic) +} +func (m *UncaughtSignal) XXX_Merge(src proto.Message) { + xxx_messageInfo_UncaughtSignal.Merge(m, src) +} +func (m *UncaughtSignal) XXX_Size() int { + return xxx_messageInfo_UncaughtSignal.Size(m) +} +func (m *UncaughtSignal) XXX_DiscardUnknown() { + xxx_messageInfo_UncaughtSignal.DiscardUnknown(m) +} + +var xxx_messageInfo_UncaughtSignal proto.InternalMessageInfo + +func (m *UncaughtSignal) GetTid() int32 { + if m != nil { + return m.Tid + } + return 0 +} + +func (m *UncaughtSignal) GetPid() int32 { + if m != nil { + return m.Pid + } + return 0 +} + +func (m *UncaughtSignal) GetRegisters() *registers_go_proto.Registers { + if m != nil { + return m.Registers + } + return nil +} + +func (m *UncaughtSignal) GetSignalNumber() int32 { + if m != nil { + return m.SignalNumber + } + return 0 +} + +func (m *UncaughtSignal) GetFaultAddr() uint64 { + if m != nil { + return m.FaultAddr + } + return 0 +} + +func init() { + proto.RegisterType((*UncaughtSignal)(nil), "gvisor.UncaughtSignal") +} + +func init() { + proto.RegisterFile("pkg/sentry/kernel/uncaught_signal.proto", fileDescriptor_5ca9e03e13704688) +} + +var fileDescriptor_5ca9e03e13704688 = []byte{ + // 210 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x4c, 0x8e, 0x4d, 0x4a, 0xc6, 0x30, + 0x10, 0x86, 0x89, 0xfd, 0x81, 0xc6, 0x1f, 0x34, 0xab, 0x20, 0x88, 0x45, 0x17, 0x76, 0xd5, 0x80, + 0x9e, 0xc0, 0x0b, 0xb8, 0x88, 0xb8, 0x2e, 0x69, 0x13, 0xd3, 0xd0, 0x9a, 0x86, 0x49, 0x22, 0x78, + 0x24, 0x6f, 0x29, 0x4d, 0xd4, 0xef, 0xdb, 0x0d, 0xcf, 0xbc, 0xf3, 0xcc, 0x8b, 0x1f, 0xdc, 0xa2, + 0x99, 0x57, 0x36, 0xc0, 0x17, 0x5b, 0x14, 0x58, 0xb5, 0xb2, 0x68, 0x27, 0x11, 0xf5, 0x1c, 0x06, + 0x6f, 0xb4, 0x15, 0x6b, 0xef, 0x60, 0x0b, 0x1b, 0xa9, 0xf5, 0xa7, 0xf1, 0x1b, 0x5c, 0xdf, 0x1e, + 0x1d, 0x08, 0x98, 0x66, 0x06, 0x4a, 0x1b, 0x1f, 0x14, 0xf8, 0x1c, 0xbc, 0xfb, 0x46, 0xf8, 0xe2, + 0xed, 0x57, 0xf1, 0x9a, 0x0c, 0xe4, 0x12, 0x17, 0xc1, 0x48, 0x8a, 0x5a, 0xd4, 0x55, 0x7c, 0x1f, + 0x77, 0xe2, 0x8c, 0xa4, 0x27, 0x99, 0x38, 0x23, 0x09, 0xc3, 0xcd, 0xbf, 0x89, 0x16, 0x2d, 0xea, + 0x4e, 0x1f, 0xaf, 0xfa, 0xfc, 0xb3, 0xe7, 0x7f, 0x0b, 0x7e, 0xc8, 0x90, 0x7b, 0x7c, 0x9e, 0x0b, + 0x0e, 0x36, 0x7e, 0x8c, 0x0a, 0x68, 0x99, 0x64, 0x67, 0x19, 0xbe, 0x24, 0x46, 0x6e, 0x30, 0x7e, + 0x17, 0x71, 0x0d, 0x83, 0x90, 0x12, 0x68, 0xd5, 0xa2, 0xae, 0xe4, 0x4d, 0x22, 0xcf, 0x52, 0xc2, + 0x58, 0xa7, 0xca, 0x4f, 0x3f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xfd, 0x62, 0x54, 0xdf, 0x06, 0x01, + 0x00, 0x00, +} diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD deleted file mode 100644 index 40025d62d..000000000 --- a/pkg/sentry/limits/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "limits", - srcs = [ - "context.go", - "limits.go", - "linux.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/limits", - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - ], -) - -go_test( - name = "limits_test", - size = "small", - srcs = [ - "limits_test.go", - ], - embed = [":limits"], -) diff --git a/pkg/sentry/limits/limits_state_autogen.go b/pkg/sentry/limits/limits_state_autogen.go new file mode 100755 index 000000000..912a99cae --- /dev/null +++ b/pkg/sentry/limits/limits_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package limits + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Limit) beforeSave() {} +func (x *Limit) save(m state.Map) { + x.beforeSave() + m.Save("Cur", &x.Cur) + m.Save("Max", &x.Max) +} + +func (x *Limit) afterLoad() {} +func (x *Limit) load(m state.Map) { + m.Load("Cur", &x.Cur) + m.Load("Max", &x.Max) +} + +func (x *LimitSet) beforeSave() {} +func (x *LimitSet) save(m state.Map) { + x.beforeSave() + m.Save("data", &x.data) +} + +func (x *LimitSet) afterLoad() {} +func (x *LimitSet) load(m state.Map) { + m.Load("data", &x.data) +} + +func init() { + state.Register("limits.Limit", (*Limit)(nil), state.Fns{Save: (*Limit).save, Load: (*Limit).load}) + state.Register("limits.LimitSet", (*LimitSet)(nil), state.Fns{Save: (*LimitSet).save, Load: (*LimitSet).load}) +} diff --git a/pkg/sentry/limits/limits_test.go b/pkg/sentry/limits/limits_test.go deleted file mode 100644 index 658a20f56..000000000 --- a/pkg/sentry/limits/limits_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package limits - -import ( - "syscall" - "testing" -) - -func TestSet(t *testing.T) { - testCases := []struct { - limit Limit - privileged bool - expectedErr error - }{ - {limit: Limit{Cur: 50, Max: 50}, privileged: false, expectedErr: nil}, - {limit: Limit{Cur: 20, Max: 50}, privileged: false, expectedErr: nil}, - {limit: Limit{Cur: 20, Max: 60}, privileged: false, expectedErr: syscall.EPERM}, - {limit: Limit{Cur: 60, Max: 50}, privileged: false, expectedErr: syscall.EINVAL}, - {limit: Limit{Cur: 11, Max: 10}, privileged: false, expectedErr: syscall.EINVAL}, - {limit: Limit{Cur: 20, Max: 60}, privileged: true, expectedErr: nil}, - } - - ls := NewLimitSet() - for _, tc := range testCases { - if _, err := ls.Set(1, tc.limit, tc.privileged); err != tc.expectedErr { - t.Fatalf("Tried to set Limit to %+v and privilege %t: got %v, wanted %v", tc.limit, tc.privileged, err, tc.expectedErr) - } - } - -} diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD deleted file mode 100644 index 3b322f5f3..000000000 --- a/pkg/sentry/loader/BUILD +++ /dev/null @@ -1,51 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_embed_data") - -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_embed_data( - name = "vdso_bin", - src = "//vdso:vdso.so", - package = "loader", - var = "vdsoBin", -) - -go_library( - name = "loader", - srcs = [ - "elf.go", - "interpreter.go", - "loader.go", - "vdso.go", - "vdso_state.go", - ":vdso_bin", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/loader", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/cpuid", - "//pkg/log", - "//pkg/rand", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", - "//pkg/sentry/safemem", - "//pkg/sentry/uniqueid", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/loader/loader_state_autogen.go b/pkg/sentry/loader/loader_state_autogen.go new file mode 100755 index 000000000..4fe3c779d --- /dev/null +++ b/pkg/sentry/loader/loader_state_autogen.go @@ -0,0 +1,57 @@ +// automatically generated by stateify. + +package loader + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *VDSO) beforeSave() {} +func (x *VDSO) save(m state.Map) { + x.beforeSave() + var phdrs []elfProgHeader = x.savePhdrs() + m.SaveValue("phdrs", phdrs) + m.Save("ParamPage", &x.ParamPage) + m.Save("vdso", &x.vdso) + m.Save("os", &x.os) + m.Save("arch", &x.arch) +} + +func (x *VDSO) afterLoad() {} +func (x *VDSO) load(m state.Map) { + m.Load("ParamPage", &x.ParamPage) + m.Load("vdso", &x.vdso) + m.Load("os", &x.os) + m.Load("arch", &x.arch) + m.LoadValue("phdrs", new([]elfProgHeader), func(y interface{}) { x.loadPhdrs(y.([]elfProgHeader)) }) +} + +func (x *elfProgHeader) beforeSave() {} +func (x *elfProgHeader) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("Flags", &x.Flags) + m.Save("Off", &x.Off) + m.Save("Vaddr", &x.Vaddr) + m.Save("Paddr", &x.Paddr) + m.Save("Filesz", &x.Filesz) + m.Save("Memsz", &x.Memsz) + m.Save("Align", &x.Align) +} + +func (x *elfProgHeader) afterLoad() {} +func (x *elfProgHeader) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("Flags", &x.Flags) + m.Load("Off", &x.Off) + m.Load("Vaddr", &x.Vaddr) + m.Load("Paddr", &x.Paddr) + m.Load("Filesz", &x.Filesz) + m.Load("Memsz", &x.Memsz) + m.Load("Align", &x.Align) +} + +func init() { + state.Register("loader.VDSO", (*VDSO)(nil), state.Fns{Save: (*VDSO).save, Load: (*VDSO).load}) + state.Register("loader.elfProgHeader", (*elfProgHeader)(nil), state.Fns{Save: (*elfProgHeader).save, Load: (*elfProgHeader).load}) +} diff --git a/pkg/sentry/loader/vdso_bin.go b/pkg/sentry/loader/vdso_bin.go new file mode 100755 index 000000000..cf351f9dc --- /dev/null +++ b/pkg/sentry/loader/vdso_bin.go @@ -0,0 +1,5 @@ +// Generated by go_embed_data for //pkg/sentry/loader:vdso_bin. DO NOT EDIT. + +package loader + +var vdsoBin = []byte("ELF\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00>\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x008\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00p\xff\xff\xff\xff\xff\x96\x00\x00\x00\x00\x00\x00\x96\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00x\x00\x00\x00\x00\x00\x00xp\xff\xff\xff\xff\xffxp\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$\x00\x00\x00\x00\x00\x00$p\xff\xff\xff\xff\xff$p\xff\xff\xff\xff\xff@\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00P\xe5td\x00\x00\x00d\x00\x00\x00\x00\x00\x00dp\xff\xff\xff\xff\xffdp\xff\xff\xff\xff\xff<\x00\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00D\x00\x00\x00\"\x00\x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x00R\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00^\x00\x00\x00\"\x00\x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00k\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00y\x00\x00\x00\"\x00\x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00~\x00\x00\x00\"\x00\x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00linux-vdso.so.1\x00LINUX_2.6\x00__vdso_clock_gettime\x00__vdso_gettimeofday\x00clock_gettime\x00__vdso_time\x00gettimeofday\x00__vdso_getcpu\x00time\x00getcpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa1\xbf\xee
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf6u\xae\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\x00\x00GNU\x00gold 1.11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00GNU\x00\x8e\xf8\x94\x9b\xe5\xdd[.\xba\xd9T\xc5\xdcFc.Jd;8\x00\x00\x00\x00\x00\x00\xdc\x00\x00T\x00\x00\x00
\x00\x00l\x00\x00\x00|
\x00\x00\x9c\x00\x00\x00\xac
\x00\x00\xbc\x00\x00\x00\xbc
\x00\x00\xd4\x00\x00\x00|\x00\x00\xf4\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00zR\x00x\x90\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,\x00\x00\x004\x00\x00\x00\x98\x00\x00l\x00\x00\x00\x00A\x86A\x83G0X\nAAE\x00\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\xd8\x00\x00\"\x00\x00\x00\x00A\x83G XA\x00\x00\x00\x00\x84\x00\x00\x00\xe8\x00\x00\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x9c\x00\x00\x00\xe0\x00\x00\xb3\x00\x00\x00\x00A\x83\nhJ\x00\x00\x00\xbc\x00\x00\x00\x80
\x00\x00\xb6\x00\x00\x00\x00A\x83\nhM\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0\xff\xffo\x00\x00\x00\x00\xd6p\xff\xff\xff\xff\xff\xfc\xff\xffo\x00\x00\x00\x00\xecp\xff\xff\xff\xff\xff\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\xfft\x83\xffuH\x89\xf7\xe9\x8f\x00\x00\x80\x00\x00\x00\x00\xb8\xe4\x00\x00\x00\xc3H\x89\xf7\xe9\xb8\x00\x00\x00\x84\x00\x00\x00\x00\x00USH\x89\xf3H\x83\xecH\x85\xfft;H\x89\xfdH\x89\xe7\xe8\x97\x00\x00\x00\x85\xc0u@H\x8b$H\x8bL$H\xba\xcf\xf7S㥛\xc4 H\x89E\x00H\x89\xc8H\xc1\xf9?H\xf7\xeaH\xc1\xfaH)\xcaH\x89UH\x85\xdbt\xc7\x00\x00\x00\x00\xc7C\x00\x00\x00\x001\xc0H\x83\xc4[]\xc3@\x001\xc0\xeb\xf1@\x00SH\x89\xfbH\x83\xecH\x89\xe7\xe80\x00\x00\x00H\x85\xdbH\x8b$tH\x89H\x83\xc4[\xc3@\x00f.\x84\x00\x00\x00\x00\x00\xb85\x00\x00H\x98Ð\x90\x90\x90\x90\x90SH\x89\xfeH\x8d
\xd5\xde\xff\xffH\x8b9\x83\xe7\xfeH\x8bY(L\x8bA8Lc\xcfL\x8bY0L\x8bQ@\xae\xe81H\x8b9L9\xcfu\xddH\x85\xdbtrH\x89щ\xc0H\xc1\xe1 H \xc11\xc0I9\xcb(H\xb8\x00\x00\x00\x00\x00ʚ;1\xd2L)\xd9I\xf7\xf2H\x89\xcfH\xc1\xff?H\xaf\xf8H\xf7\xe1H\xfaH\xac\xd0 I\x8d<\x00H\xb9SZ\x9b\xa0/\xb8D\x00[H\x89\xf8H\xc1\xe8 H\xf7\xe1H\x89\xd0H\xc1\xe8H\x89Hi\xc0\x00ʚ;H)\xc71\xc0H\x89~\xc3\x00\xb8\xe4\x00\x00\x001\xff[\xc3\x00f.\x84\x00\x00\x00\x00\x00SH\x89\xfeH\x8d
\xde\xff\xffH\x8b9\x83\xe7\xfeH\x8bYL\x8bALc\xcfL\x8bYL\x8bQ \xae\xe81H\x8b9L9\xcfu\xddH\x85\xdbtrH\x89щ\xc0H\xc1\xe1 H \xc11\xc0I9\xcb(H\xb8\x00\x00\x00\x00\x00ʚ;1\xd2L)\xd9I\xf7\xf2H\x89\xcfH\xc1\xff?H\xaf\xf8H\xf7\xe1H\xfaH\xac\xd0 I\x8d<\x00H\xb9SZ\x9b\xa0/\xb8D\x00[H\x89\xf8H\xc1\xe8 H\xf7\xe1H\x89\xd0H\xc1\xe8H\x89Hi\xc0\x00ʚ;H)\xc71\xc0H\x89~\xc3\x00\xb8\xe4\x00\x00\x00\xbf\x00\x00\x00[\xc3\x00GCC: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff\xb3\x00\x00\x00\x00\x00\x00\x009\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\xb6\x00\x00\x00\x00\x00\x00\x00]\x00\x00\x00 \x00xp\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00f\x00\x00\x00\x00\x00\xf1\xff\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00s\x00\x00\x00\x00\x00\xf1\xff\x00\xf0o\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00{\x00\x00\x00\x00\xf1\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x00\x9a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xb0\x00\x00\x00\x00\x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00\xc4\x00\x00\x00\"\x00\x00@p\xff\xff\xff\xff\xff(\x00\x00\x00\x00\x00\x00\x00\xd2\x00\x00\x00\x00\x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00\xde\x00\x00\x00\"\x00\x00pp\xff\xff\xff\xff\xffl\x00\x00\x00\x00\x00\x00\x00\xeb\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\xf9\x00\x00\x00\"\x00\x00\xe0p\xff\xff\xff\xff\xff\"\x00\x00\x00\x00\x00\x00\x00\xfe\x00\x00\x00\"\x00\x00p\xff\xff\xff\xff\xff\n\x00\x00\x00\x00\x00\x00\x00\x00vdso.cc\x00vdso_time.cc\x00_ZN4vdso13ClockRealtimeEP8timespec\x00_ZN4vdso14ClockMonotonicEP8timespec\x00_DYNAMIC\x00VDSO_PRELINK\x00_params\x00LINUX_2.6\x00__vdso_clock_gettime\x00_GLOBAL_OFFSET_TABLE_\x00__vdso_gettimeofday\x00clock_gettime\x00__vdso_time\x00gettimeofday\x00__vdso_getcpu\x00time\x00getcpu\x00\x00.text\x00.comment\x00.bss\x00.dynstr\x00.eh_frame_hdr\x00.gnu.version\x00.dynsym\x00.hash\x00.note\x00.eh_frame\x00.gnu.version_d\x00.dynamic\x00.shstrtab\x00.strtab\x00.symtab\x00.data\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 p\xff\xff\xff\xff\xff \x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`p\xff\xff\xff\xff\xff`\x00\x00\x00\x00\x00\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00Pp\xff\xff\xff\xff\xffP\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00+\x00\x00\x00\xff\xff\xffo\x00\x00\x00\x00\x00\x00\x00\xd6p\xff\xff\xff\xff\xff\xd6\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00V\x00\x00\x00\xfd\xff\xffo\x00\x00\x00\x00\x00\x00\x00\xecp\xff\xff\xff\xff\xff\xec\x00\x00\x00\x00\x00\x008\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00F\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00$p\xff\xff\xff\xff\xff$\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00dp\xff\xff\xff\xff\xffd\x00\x00\x00\x00\x00\x00<\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00L\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa0p\xff\xff\xff\xff\xff\xa0\x00\x00\x00\x00\x00\x00\xd8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00xp\xff\xff\xff\xff\xffx\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x88\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@p\xff\xff\xff\xff\xff@\x00\x00\x00\x00\x00\x00V\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x96\x00\x00\x00\x00\x00\x006\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xd0\x00\x00\x00\x00\x00\x00\xb0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00x\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x85\x00\x00\x00\x00\x00\x00\x8e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD deleted file mode 100644 index 29c14ec56..000000000 --- a/pkg/sentry/memmap/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "mappable_range", - out = "mappable_range.go", - package = "memmap", - prefix = "Mappable", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "mapping_set_impl", - out = "mapping_set_impl.go", - package = "memmap", - prefix = "Mapping", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "MappableRange", - "Value": "MappingsOfRange", - "Functions": "mappingSetFunctions", - }, -) - -go_library( - name = "memmap", - srcs = [ - "mappable_range.go", - "mapping_set.go", - "mapping_set_impl.go", - "memmap.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/memmap", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/context", - "//pkg/sentry/platform", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) - -go_test( - name = "memmap_test", - size = "small", - srcs = ["mapping_set_test.go"], - embed = [":memmap"], - deps = ["//pkg/sentry/usermem"], -) diff --git a/pkg/sentry/memmap/mappable_range.go b/pkg/sentry/memmap/mappable_range.go new file mode 100755 index 000000000..6b6c2c685 --- /dev/null +++ b/pkg/sentry/memmap/mappable_range.go @@ -0,0 +1,62 @@ +package memmap + +// A Range represents a contiguous range of T. +// +// +stateify savable +type MappableRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r MappableRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r MappableRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r MappableRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r MappableRange) Overlaps(r2 MappableRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r MappableRange) IsSupersetOf(r2 MappableRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r MappableRange) Intersect(r2 MappableRange) MappableRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r MappableRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/memmap/mapping_set_impl.go b/pkg/sentry/memmap/mapping_set_impl.go new file mode 100755 index 000000000..eb3071e89 --- /dev/null +++ b/pkg/sentry/memmap/mapping_set_impl.go @@ -0,0 +1,1270 @@ +package memmap + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + MappingminDegree = 3 + + MappingmaxDegree = 2 * MappingminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type MappingSet struct { + root Mappingnode `state:".(*MappingSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *MappingSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *MappingSet) IsEmptyRange(r MappableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *MappingSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *MappingSet) SpanRange(r MappableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *MappingSet) FirstSegment() MappingIterator { + if s.root.nrSegments == 0 { + return MappingIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *MappingSet) LastSegment() MappingIterator { + if s.root.nrSegments == 0 { + return MappingIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *MappingSet) FirstGap() MappingGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return MappingGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *MappingSet) LastGap() MappingGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return MappingGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *MappingSet) Find(key uint64) (MappingIterator, MappingGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return MappingIterator{n, i}, MappingGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return MappingIterator{}, MappingGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *MappingSet) FindSegment(key uint64) MappingIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *MappingSet) LowerBoundSegment(min uint64) MappingIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *MappingSet) UpperBoundSegment(max uint64) MappingIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *MappingSet) FindGap(key uint64) MappingGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *MappingSet) LowerBoundGap(min uint64) MappingGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *MappingSet) UpperBoundGap(max uint64) MappingGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *MappingSet) Add(r MappableRange, val MappingsOfRange) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *MappingSet) AddWithoutMerging(r MappableRange, val MappingsOfRange) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *MappingSet) Insert(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (mappingSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *MappingSet) InsertWithoutMerging(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *MappingSet) InsertWithoutMergingUnchecked(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return MappingIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *MappingSet) Remove(seg MappingIterator) MappingGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + mappingSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(MappingGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *MappingSet) RemoveAll() { + s.root = Mappingnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *MappingSet) RemoveRange(r MappableRange) MappingGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *MappingSet) Merge(first, second MappingIterator) MappingIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *MappingSet) MergeUnchecked(first, second MappingIterator) MappingIterator { + if first.End() == second.Start() { + if mval, ok := (mappingSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return MappingIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *MappingSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *MappingSet) MergeRange(r MappableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *MappingSet) MergeAdjacent(r MappableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *MappingSet) Split(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *MappingSet) SplitUnchecked(seg MappingIterator, split uint64) (MappingIterator, MappingIterator) { + val1, val2 := (mappingSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), MappableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *MappingSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *MappingSet) Isolate(seg MappingIterator, r MappableRange) MappingIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *MappingSet) ApplyContiguous(r MappableRange, fn func(seg MappingIterator)) MappingGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return MappingGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return MappingGapIterator{} + } + } +} + +// +stateify savable +type Mappingnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *Mappingnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [MappingmaxDegree - 1]MappableRange + values [MappingmaxDegree - 1]MappingsOfRange + children [MappingmaxDegree]*Mappingnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Mappingnode) firstSegment() MappingIterator { + for n.hasChildren { + n = n.children[0] + } + return MappingIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *Mappingnode) lastSegment() MappingIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return MappingIterator{n, n.nrSegments - 1} +} + +func (n *Mappingnode) prevSibling() *Mappingnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *Mappingnode) nextSibling() *Mappingnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *Mappingnode) rebalanceBeforeInsert(gap MappingGapIterator) MappingGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < MappingmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:MappingminDegree-1], n.keys[:MappingminDegree-1]) + copy(left.values[:MappingminDegree-1], n.values[:MappingminDegree-1]) + copy(right.keys[:MappingminDegree-1], n.keys[MappingminDegree:]) + copy(right.values[:MappingminDegree-1], n.values[MappingminDegree:]) + n.keys[0], n.values[0] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1] + MappingzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:MappingminDegree], n.children[:MappingminDegree]) + copy(right.children[:MappingminDegree], n.children[MappingminDegree:]) + MappingzeroNodeSlice(n.children[2:]) + for i := 0; i < MappingminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < MappingminDegree { + return MappingGapIterator{left, gap.index} + } + return MappingGapIterator{right, gap.index - MappingminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[MappingminDegree-1], n.values[MappingminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &Mappingnode{ + nrSegments: MappingminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:MappingminDegree-1], n.keys[MappingminDegree:]) + copy(sibling.values[:MappingminDegree-1], n.values[MappingminDegree:]) + MappingzeroValueSlice(n.values[MappingminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:MappingminDegree], n.children[MappingminDegree:]) + MappingzeroNodeSlice(n.children[MappingminDegree:]) + for i := 0; i < MappingminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = MappingminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < MappingminDegree { + return gap + } + return MappingGapIterator{sibling, gap.index - MappingminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIterator { + for { + if n.nrSegments >= MappingminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return MappingGapIterator{n, 0} + } + if gap.node == n { + return MappingGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= MappingminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + mappingSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return MappingGapIterator{n, n.nrSegments} + } + return MappingGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return MappingGapIterator{p, gap.index} + } + if gap.node == right { + return MappingGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *Mappingnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = MappingGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + mappingSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type MappingIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *Mappingnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg MappingIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg MappingIterator) Range() MappableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg MappingIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg MappingIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg MappingIterator) SetRangeUnchecked(r MappableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetRange(r MappableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg MappingIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg MappingIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg MappingIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg MappingIterator) Value() MappingsOfRange { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg MappingIterator) ValuePtr() *MappingsOfRange { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg MappingIterator) SetValue(val MappingsOfRange) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg MappingIterator) PrevSegment() MappingIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return MappingIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return MappingIterator{} + } + return MappingsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg MappingIterator) NextSegment() MappingIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return MappingIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return MappingIterator{} + } + return MappingsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg MappingIterator) PrevGap() MappingGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return MappingGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg MappingIterator) NextGap() MappingGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return MappingGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg MappingIterator) PrevNonEmpty() (MappingIterator, MappingGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return MappingIterator{}, gap + } + return gap.PrevSegment(), MappingGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg MappingIterator) NextNonEmpty() (MappingIterator, MappingGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return MappingIterator{}, gap + } + return gap.NextSegment(), MappingGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type MappingGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *Mappingnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap MappingGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap MappingGapIterator) Range() MappableRange { + return MappableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap MappingGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return mappingSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap MappingGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return mappingSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap MappingGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap MappingGapIterator) PrevSegment() MappingIterator { + return MappingsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap MappingGapIterator) NextSegment() MappingIterator { + return MappingsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap MappingGapIterator) PrevGap() MappingGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return MappingGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap MappingGapIterator) NextGap() MappingGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return MappingGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func MappingsegmentBeforePosition(n *Mappingnode, i int) MappingIterator { + for i == 0 { + if n.parent == nil { + return MappingIterator{} + } + n, i = n.parent, n.parentIndex + } + return MappingIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func MappingsegmentAfterPosition(n *Mappingnode, i int) MappingIterator { + for i == n.nrSegments { + if n.parent == nil { + return MappingIterator{} + } + n, i = n.parent, n.parentIndex + } + return MappingIterator{n, i} +} + +func MappingzeroValueSlice(slice []MappingsOfRange) { + + for i := range slice { + mappingSetFunctions{}.ClearValue(&slice[i]) + } +} + +func MappingzeroNodeSlice(slice []*Mappingnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *MappingSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *Mappingnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *Mappingnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type MappingSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []MappingsOfRange +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *MappingSet) ExportSortedSlices() *MappingSegmentDataSlices { + var sds MappingSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *MappingSet) ImportSortedSlices(sds *MappingSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := MappableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *MappingSet) saveRoot() *MappingSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *MappingSet) loadRoot(sds *MappingSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go deleted file mode 100644 index f9b11a59c..000000000 --- a/pkg/sentry/memmap/mapping_set_test.go +++ /dev/null @@ -1,260 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memmap - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -type testMappingSpace struct { - // Ideally we'd store the full ranges that were invalidated, rather - // than individual calls to Invalidate, as they are an implementation - // detail, but this is the simplest way for now. - inv []usermem.AddrRange -} - -func (n *testMappingSpace) reset() { - n.inv = []usermem.AddrRange{} -} - -func (n *testMappingSpace) Invalidate(ar usermem.AddrRange, opts InvalidateOpts) { - n.inv = append(n.inv, ar) -} - -func TestAddRemoveMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings (usermem.AddrRanges => memmap.MappableRange): - // [0x10000, 0x12000) => [0x1000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(mapped) != 0 { - t.Errorf("AddMapping: got %+v, wanted []", mapped) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true) - if got, want := mapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x30000, 0x31000) => [0x4000, 0x5000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x12000, 0x15000}, 0x3000, true) - if got, want := mapped, []MappableRange{{0x3000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) => [0x1000, 0x2000) - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x12000, 0x13000) => [0x3000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0x1000, true) - if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000) - // [0x12000, 0x13000) => [0x3000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Mappings: - // [0x11000, 0x13000) => [0x2000, 0x4000) - // [0x13000, 0x14000) and [0x30000, 0x31000) => [0x4000, 0x5000) - // [0x14000, 0x15000) => [0x5000, 0x6000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x15000}, 0x2000, true) - if got, want := unmapped, []MappableRange{{0x2000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x30000, 0x31000) => [0x4000, 0x5000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true) - if got, want := unmapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateWholeMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true) - // Mappings: - // [0x10000, 0x11000) => [0, 0x1000) - t.Log(&set) - set.Invalidate(MappableRange{0, 0x1000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidatePartialMapping(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x13000}, 0, true) - // Mappings: - // [0x10000, 0x13000) => [0, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateMultipleMappings(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true) - set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - // Mappings: - // [0x10000, 0x11000) => [0, 0x1000) - // [0x12000, 0x13000) => [0x2000, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0, 0x3000}, InvalidateOpts{}) - if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: got %+v, wanted %+v", got, want) - } -} - -func TestInvalidateOverlappingMappings(t *testing.T) { - set := MappingSet{} - ms1 := &testMappingSpace{} - ms2 := &testMappingSpace{} - - set.AddMapping(ms1, usermem.AddrRange{0x10000, 0x12000}, 0, true) - set.AddMapping(ms2, usermem.AddrRange{0x20000, 0x22000}, 0x1000, true) - // Mappings: - // ms1:[0x10000, 0x12000) => [0, 0x2000) - // ms2:[0x11000, 0x13000) => [0x1000, 0x3000) - t.Log(&set) - set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{}) - if got, want := ms1.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) - } - if got, want := ms2.inv, []usermem.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) { - t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want) - } -} - -func TestMixedWritableMappings(t *testing.T) { - set := MappingSet{} - ms := &testMappingSpace{} - - mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - t.Log(&set) - - mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x22000}, 0x2000, false) - if got, want := mapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { - t.Errorf("AddMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x11000) writable => [0x1000, 0x2000) - // [0x11000, 0x12000) writable and [0x20000, 0x21000) readonly => [0x2000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - // Unmap should fail because we specified the readonly map address range, but - // asked to unmap a writable segment. - unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Readonly mapping removed, but writable mapping still exists in the range, - // so no mappable range fully unmapped. - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, false) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x12000}, 0x2000, true) - if got, want := unmapped, []MappableRange{{0x2000, 0x3000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x10000, 0x12000) writable => [0x1000, 0x3000) - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - // Unmap should fail since writable bit doesn't match. - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, false) - if len(unmapped) != 0 { - t.Errorf("RemoveMapping: got %+v, wanted []", unmapped) - } - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true) - if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } - - // Mappings: - // [0x21000, 0x22000) readonly => [0x3000, 0x4000) - t.Log(&set) - - unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x21000, 0x22000}, 0x3000, false) - if got, want := unmapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) { - t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want) - } -} diff --git a/pkg/sentry/memmap/memmap_state_autogen.go b/pkg/sentry/memmap/memmap_state_autogen.go new file mode 100755 index 000000000..2d0814ab1 --- /dev/null +++ b/pkg/sentry/memmap/memmap_state_autogen.go @@ -0,0 +1,93 @@ +// automatically generated by stateify. + +package memmap + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *MappableRange) beforeSave() {} +func (x *MappableRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *MappableRange) afterLoad() {} +func (x *MappableRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *MappingOfRange) beforeSave() {} +func (x *MappingOfRange) save(m state.Map) { + x.beforeSave() + m.Save("MappingSpace", &x.MappingSpace) + m.Save("AddrRange", &x.AddrRange) + m.Save("Writable", &x.Writable) +} + +func (x *MappingOfRange) afterLoad() {} +func (x *MappingOfRange) load(m state.Map) { + m.Load("MappingSpace", &x.MappingSpace) + m.Load("AddrRange", &x.AddrRange) + m.Load("Writable", &x.Writable) +} + +func (x *MappingSet) beforeSave() {} +func (x *MappingSet) save(m state.Map) { + x.beforeSave() + var root *MappingSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *MappingSet) afterLoad() {} +func (x *MappingSet) load(m state.Map) { + m.LoadValue("root", new(*MappingSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*MappingSegmentDataSlices)) }) +} + +func (x *Mappingnode) beforeSave() {} +func (x *Mappingnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *Mappingnode) afterLoad() {} +func (x *Mappingnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *MappingSegmentDataSlices) beforeSave() {} +func (x *MappingSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *MappingSegmentDataSlices) afterLoad() {} +func (x *MappingSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("memmap.MappableRange", (*MappableRange)(nil), state.Fns{Save: (*MappableRange).save, Load: (*MappableRange).load}) + state.Register("memmap.MappingOfRange", (*MappingOfRange)(nil), state.Fns{Save: (*MappingOfRange).save, Load: (*MappingOfRange).load}) + state.Register("memmap.MappingSet", (*MappingSet)(nil), state.Fns{Save: (*MappingSet).save, Load: (*MappingSet).load}) + state.Register("memmap.Mappingnode", (*Mappingnode)(nil), state.Fns{Save: (*Mappingnode).save, Load: (*Mappingnode).load}) + state.Register("memmap.MappingSegmentDataSlices", (*MappingSegmentDataSlices)(nil), state.Fns{Save: (*MappingSegmentDataSlices).save, Load: (*MappingSegmentDataSlices).load}) +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD deleted file mode 100644 index 072745a08..000000000 --- a/pkg/sentry/mm/BUILD +++ /dev/null @@ -1,142 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "file_refcount_set", - out = "file_refcount_set.go", - imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "mm", - prefix = "fileRefcount", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "platform.FileRange", - "Value": "int32", - "Functions": "fileRefcountSetFunctions", - }, -) - -go_template_instance( - name = "vma_set", - out = "vma_set.go", - consts = { - "minDegree": "8", - }, - imports = { - "usermem": "gvisor.dev/gvisor/pkg/sentry/usermem", - }, - package = "mm", - prefix = "vma", - template = "//pkg/segment:generic_set", - types = { - "Key": "usermem.Addr", - "Range": "usermem.AddrRange", - "Value": "vma", - "Functions": "vmaSetFunctions", - }, -) - -go_template_instance( - name = "pma_set", - out = "pma_set.go", - consts = { - "minDegree": "8", - }, - imports = { - "usermem": "gvisor.dev/gvisor/pkg/sentry/usermem", - }, - package = "mm", - prefix = "pma", - template = "//pkg/segment:generic_set", - types = { - "Key": "usermem.Addr", - "Range": "usermem.AddrRange", - "Value": "pma", - "Functions": "pmaSetFunctions", - }, -) - -go_template_instance( - name = "io_list", - out = "io_list.go", - package = "mm", - prefix = "io", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*ioResult", - "Linker": "*ioResult", - }, -) - -go_library( - name = "mm", - srcs = [ - "address_space.go", - "aio_context.go", - "aio_context_state.go", - "debug.go", - "file_refcount_set.go", - "io.go", - "io_list.go", - "lifecycle.go", - "metadata.go", - "mm.go", - "pma.go", - "pma_set.go", - "procfs.go", - "save_restore.go", - "shm.go", - "special_mappable.go", - "syscalls.go", - "vma.go", - "vma_set.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/mm", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/proc/seqfile", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/futex", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/safemem", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/tcpip/buffer", - "//third_party/gvsync", - ], -) - -go_test( - name = "mm_test", - size = "small", - srcs = ["mm_test.go"], - embed = [":mm"], - deps = [ - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/context/contexttest", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/mm/README.md b/pkg/sentry/mm/README.md deleted file mode 100644 index e1322e373..000000000 --- a/pkg/sentry/mm/README.md +++ /dev/null @@ -1,280 +0,0 @@ -This package provides an emulation of Linux semantics for application virtual -memory mappings. - -For completeness, this document also describes aspects of the memory management -subsystem defined outside this package. - -# Background - -We begin by describing semantics for virtual memory in Linux. - -A virtual address space is defined as a collection of mappings from virtual -addresses to physical memory. However, userspace applications do not configure -mappings to physical memory directly. Instead, applications configure memory -mappings from virtual addresses to offsets into a file using the `mmap` system -call.[^mmap-anon] For example, a call to: - - mmap( - /* addr = */ 0x400000, - /* length = */ 0x1000, - PROT_READ | PROT_WRITE, - MAP_SHARED, - /* fd = */ 3, - /* offset = */ 0); - -creates a mapping of length 0x1000 bytes, starting at virtual address (VA) -0x400000, to offset 0 in the file represented by file descriptor (FD) 3. Within -the Linux kernel, virtual memory mappings are represented by *virtual memory -areas* (VMAs). Supposing that FD 3 represents file /tmp/foo, the state of the -virtual memory subsystem after the `mmap` call may be depicted as: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - -Establishing a virtual memory area does not necessarily establish a mapping to a -physical address, because Linux has not necessarily provisioned physical memory -to store the file's contents. Thus, if the application attempts to read the -contents of VA 0x400000, it may incur a *page fault*, a CPU exception that -forces the kernel to create such a mapping to service the read. - -For a file, doing so consists of several logical phases: - -1. The kernel allocates physical memory to store the contents of the required - part of the file, and copies file contents to the allocated memory. - Supposing that the kernel chooses the physical memory at physical address - (PA) 0x2fb000, the resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - - (In Linux the state of the mapping from file offset to physical memory is - stored in `struct address_space`, but to avoid confusion with other notions - of address space we will refer to this system as filemap, named after Linux - kernel source file `mm/filemap.c`.) - -2. The kernel stores the effective mapping from virtual to physical address in - a *page table entry* (PTE) in the application's *page tables*, which are - used by the CPU's virtual memory hardware to perform address translation. - The resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x2fb000 - - The PTE is required for the application to actually use the contents of the - mapped file as virtual memory. However, the PTE is derived from the VMA and - filemap state, both of which are independently mutable, such that mutations - to either will affect the PTE. For example: - - - The application may remove the VMA using the `munmap` system call. This - breaks the mapping from VA:0x400000 to /tmp/foo:0x0, and consequently - the mapping from VA:0x400000 to PA:0x2fb000. However, it does not - necessarily break the mapping from /tmp/foo:0x0 to PA:0x2fb000, so a - future mapping of the same file offset may reuse this physical memory. - - - The application may invalidate the file's contents by passing a length - of 0 to the `ftruncate` system call. This breaks the mapping from - /tmp/foo:0x0 to PA:0x2fb000, and consequently the mapping from - VA:0x400000 to PA:0x2fb000. However, it does not break the mapping from - VA:0x400000 to /tmp/foo:0x0, so future changes to the file's contents - may again be made visible at VA:0x400000 after another page fault - results in the allocation of a new physical address. - - Note that, in order to correctly break the mapping from VA:0x400000 to - PA:0x2fb000 in the latter case, filemap must also store a *reverse mapping* - from /tmp/foo:0x0 to VA:0x400000 so that it can locate and remove the PTE. - -[^mmap-anon]: Memory mappings to non-files are discussed in later sections. - -## Private Mappings - -The preceding example considered VMAs created using the `MAP_SHARED` flag, which -means that PTEs derived from the mapping should always use physical memory that -represents the current state of the mapped file.[^mmap-dev-zero] Applications -can alternatively pass the `MAP_PRIVATE` flag to create a *private mapping*. -Private mappings are *copy-on-write*. - -Suppose that the application instead created a private mapping in the previous -example. In Linux, the state of the system after a read page fault would be: - - VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x2fb000 (read-only) - -Now suppose the application attempts to write to VA:0x400000. For a shared -mapping, the write would be propagated to PA:0x2fb000, and the kernel would be -responsible for ensuring that the write is later propagated to the mapped file. -For a private mapping, the write incurs another page fault since the PTE is -marked read-only. In response, the kernel allocates physical memory to store the -mapping's *private copy* of the file's contents, copies file contents to the -allocated memory, and changes the PTE to map to the private copy. Supposing that -the kernel chooses the physical memory at physical address (PA) 0x5ea000, the -resulting state of the system is: - - VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Filemap: /tmp/foo:0x0 -> PA:0x2fb000 - PTE: VA:0x400000 -----------------> PA:0x5ea000 - -Note that the filemap mapping from /tmp/foo:0x0 to PA:0x2fb000 may still exist, -but is now irrelevant to this mapping. - -[^mmap-dev-zero]: Modulo files with special mmap semantics such as `/dev/zero`. - -## Anonymous Mappings - -Instead of passing a file to the `mmap` system call, applications can instead -request an *anonymous* mapping by passing the `MAP_ANONYMOUS` flag. -Semantically, an anonymous mapping is essentially a mapping to an ephemeral file -initially filled with zero bytes. Practically speaking, this is how shared -anonymous mappings are implemented, but private anonymous mappings do not result -in the creation of an ephemeral file; since there would be no way to modify the -contents of the underlying file through a private mapping, all private anonymous -mappings use a single shared page filled with zero bytes until copy-on-write -occurs. - -# Virtual Memory in the Sentry - -The sentry implements application virtual memory atop a host kernel, introducing -an additional level of indirection to the above. - -Consider the same scenario as in the previous section. Since the sentry handles -application system calls, the effect of an application `mmap` system call is to -create a VMA in the sentry (as opposed to the host kernel): - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - -When the application first incurs a page fault on this address, the host kernel -delivers information about the page fault to the sentry in a platform-dependent -manner, and the sentry handles the fault: - -1. The sentry allocates memory to store the contents of the required part of - the file, and copies file contents to the allocated memory. However, since - the sentry is implemented atop a host kernel, it does not configure mappings - to physical memory directly. Instead, mappable "memory" in the sentry is - represented by a host file descriptor and offset, since (as noted in - "Background") this is the memory mapping primitive provided by the host - kernel. In general, memory is allocated from a temporary host file using the - `pgalloc` package. Supposing that the sentry allocates offset 0x3000 from - host file "memory-file", the resulting state is: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - -2. The sentry stores the effective mapping from virtual address to host file in - a host VMA by invoking the `mmap` system call: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 - -3. The sentry returns control to the application, which immediately incurs the - page fault again.[^mmap-populate] However, since a host VMA now exists for - the faulting virtual address, the host kernel now handles the page fault as - described in "Background": - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 - Host filemap: host:memory-file:0x3000 -> PA:0x2fb000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000 - -Thus, from an implementation standpoint, host VMAs serve the same purpose in the -sentry that PTEs do in Linux. As in Linux, sentry VMA and filemap state is -independently mutable, and the desired state of host VMAs is derived from that -state. - -[^mmap-populate]: The sentry could force the host kernel to establish PTEs when - it creates the host VMA by passing the `MAP_POPULATE` flag to - the `mmap` system call, but usually does not. This is because, - to reduce the number of page faults that require handling by - the sentry and (correspondingly) the number of host `mmap` - system calls, the sentry usually creates host VMAs that are - much larger than the single faulting page. - -## Private Mappings - -The sentry implements private mappings consistently with Linux. Before -copy-on-write, the private mapping example given in the Background results in: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x3000 (read-only) - Host filemap: host:memory-file:0x3000 -> PA:0x2fb000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x2fb000 (read-only) - -When the application attempts to write to this address, the host kernel delivers -information about the resulting page fault to the sentry. Analogous to Linux, -the sentry allocates memory to store the mapping's private copy of the file's -contents, copies file contents to the allocated memory, and changes the host VMA -to map to the private copy. Supposing that the sentry chooses the offset 0x4000 -in host file `memory-file` to store the private copy, the state of the system -after copy-on-write is: - - Sentry VMA: VA:0x400000 -> /tmp/foo:0x0 (private) - Sentry filemap: /tmp/foo:0x0 -> host:memory-file:0x3000 - Host VMA: VA:0x400000 -----------------> host:memory-file:0x4000 - Host filemap: host:memory-file:0x4000 -> PA:0x5ea000 - Host PTE: VA:0x400000 --------------------------------------------> PA:0x5ea000 - -However, this highlights an important difference between Linux and the sentry. -In Linux, page tables are concrete (architecture-dependent) data structures -owned by the kernel. Conversely, the sentry has the ability to create and -destroy host VMAs using host system calls, but it does not have direct access to -their state. Thus, as written, if the application invokes the `munmap` system -call to remove the sentry VMA, it is non-trivial for the sentry to determine -that it should deallocate `host:memory-file:0x4000`. This implies that the -sentry must retain information about the host VMAs that it has created. - -## Anonymous Mappings - -The sentry implements anonymous mappings consistently with Linux, except that -there is no shared zero page. - -# Implementation Constructs - -In Linux: - -- A virtual address space is represented by `struct mm_struct`. - -- VMAs are represented by `struct vm_area_struct`, stored in `struct - mm_struct::mmap`. - -- Mappings from file offsets to physical memory are stored in `struct - address_space`. - -- Reverse mappings from file offsets to virtual mappings are stored in `struct - address_space::i_mmap`. - -- Physical memory pages are represented by a pointer to `struct page` or an - index called a *page frame number* (PFN), represented by `pfn_t`. - -- PTEs are represented by architecture-dependent type `pte_t`, stored in a - table hierarchy rooted at `struct mm_struct::pgd`. - -In the sentry: - -- A virtual address space is represented by type [`mm.MemoryManager`][mm]. - -- Sentry VMAs are represented by type [`mm.vma`][mm], stored in - `mm.MemoryManager.vmas`. - -- Mappings from sentry file offsets to host file offsets are abstracted - through interface method [`memmap.Mappable.Translate`][memmap]. - -- Reverse mappings from sentry file offsets to virtual mappings are abstracted - through interface methods - [`memmap.Mappable.AddMapping` and `memmap.Mappable.RemoveMapping`][memmap]. - -- Host files that may be mapped into host VMAs are represented by type - [`platform.File`][platform]. - -- Host VMAs are represented in the sentry by type [`mm.pma`][mm] ("platform - mapping area"), stored in `mm.MemoryManager.pmas`. - -- Creation and destruction of host VMAs is abstracted through interface - methods - [`platform.AddressSpace.MapFile` and `platform.AddressSpace.Unmap`][platform]. - -[memmap]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/memmap/memmap.go -[mm]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/mm/mm.go -[pgalloc]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/pgalloc/pgalloc.go -[platform]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/platform/platform.go diff --git a/pkg/sentry/mm/file_refcount_set.go b/pkg/sentry/mm/file_refcount_set.go new file mode 100755 index 000000000..84a7fd759 --- /dev/null +++ b/pkg/sentry/mm/file_refcount_set.go @@ -0,0 +1,1274 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/platform" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + fileRefcountminDegree = 3 + + fileRefcountmaxDegree = 2 * fileRefcountminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type fileRefcountSet struct { + root fileRefcountnode `state:".(*fileRefcountSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *fileRefcountSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *fileRefcountSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *fileRefcountSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *fileRefcountSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *fileRefcountSet) FirstSegment() fileRefcountIterator { + if s.root.nrSegments == 0 { + return fileRefcountIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *fileRefcountSet) LastSegment() fileRefcountIterator { + if s.root.nrSegments == 0 { + return fileRefcountIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *fileRefcountSet) FirstGap() fileRefcountGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return fileRefcountGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *fileRefcountSet) LastGap() fileRefcountGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return fileRefcountGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *fileRefcountSet) Find(key uint64) (fileRefcountIterator, fileRefcountGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return fileRefcountIterator{n, i}, fileRefcountGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return fileRefcountIterator{}, fileRefcountGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *fileRefcountSet) FindSegment(key uint64) fileRefcountIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *fileRefcountSet) LowerBoundSegment(min uint64) fileRefcountIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *fileRefcountSet) UpperBoundSegment(max uint64) fileRefcountIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *fileRefcountSet) FindGap(key uint64) fileRefcountGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *fileRefcountSet) LowerBoundGap(min uint64) fileRefcountGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *fileRefcountSet) UpperBoundGap(max uint64) fileRefcountGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *fileRefcountSet) Add(r __generics_imported0.FileRange, val int32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *fileRefcountSet) AddWithoutMerging(r __generics_imported0.FileRange, val int32) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *fileRefcountSet) Insert(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (fileRefcountSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (fileRefcountSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *fileRefcountSet) InsertWithoutMerging(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *fileRefcountSet) InsertWithoutMergingUnchecked(gap fileRefcountGapIterator, r __generics_imported0.FileRange, val int32) fileRefcountIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return fileRefcountIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *fileRefcountSet) Remove(seg fileRefcountIterator) fileRefcountGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + fileRefcountSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(fileRefcountGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *fileRefcountSet) RemoveAll() { + s.root = fileRefcountnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *fileRefcountSet) RemoveRange(r __generics_imported0.FileRange) fileRefcountGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *fileRefcountSet) Merge(first, second fileRefcountIterator) fileRefcountIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *fileRefcountSet) MergeUnchecked(first, second fileRefcountIterator) fileRefcountIterator { + if first.End() == second.Start() { + if mval, ok := (fileRefcountSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return fileRefcountIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *fileRefcountSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *fileRefcountSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *fileRefcountSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *fileRefcountSet) Split(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *fileRefcountSet) SplitUnchecked(seg fileRefcountIterator, split uint64) (fileRefcountIterator, fileRefcountIterator) { + val1, val2 := (fileRefcountSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *fileRefcountSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *fileRefcountSet) Isolate(seg fileRefcountIterator, r __generics_imported0.FileRange) fileRefcountIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *fileRefcountSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg fileRefcountIterator)) fileRefcountGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return fileRefcountGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return fileRefcountGapIterator{} + } + } +} + +// +stateify savable +type fileRefcountnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *fileRefcountnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [fileRefcountmaxDegree - 1]__generics_imported0.FileRange + values [fileRefcountmaxDegree - 1]int32 + children [fileRefcountmaxDegree]*fileRefcountnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *fileRefcountnode) firstSegment() fileRefcountIterator { + for n.hasChildren { + n = n.children[0] + } + return fileRefcountIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *fileRefcountnode) lastSegment() fileRefcountIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return fileRefcountIterator{n, n.nrSegments - 1} +} + +func (n *fileRefcountnode) prevSibling() *fileRefcountnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *fileRefcountnode) nextSibling() *fileRefcountnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *fileRefcountnode) rebalanceBeforeInsert(gap fileRefcountGapIterator) fileRefcountGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < fileRefcountmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:fileRefcountminDegree-1], n.keys[:fileRefcountminDegree-1]) + copy(left.values[:fileRefcountminDegree-1], n.values[:fileRefcountminDegree-1]) + copy(right.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:]) + copy(right.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:]) + n.keys[0], n.values[0] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1] + fileRefcountzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:fileRefcountminDegree], n.children[:fileRefcountminDegree]) + copy(right.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:]) + fileRefcountzeroNodeSlice(n.children[2:]) + for i := 0; i < fileRefcountminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < fileRefcountminDegree { + return fileRefcountGapIterator{left, gap.index} + } + return fileRefcountGapIterator{right, gap.index - fileRefcountminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[fileRefcountminDegree-1], n.values[fileRefcountminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &fileRefcountnode{ + nrSegments: fileRefcountminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:fileRefcountminDegree-1], n.keys[fileRefcountminDegree:]) + copy(sibling.values[:fileRefcountminDegree-1], n.values[fileRefcountminDegree:]) + fileRefcountzeroValueSlice(n.values[fileRefcountminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:fileRefcountminDegree], n.children[fileRefcountminDegree:]) + fileRefcountzeroNodeSlice(n.children[fileRefcountminDegree:]) + for i := 0; i < fileRefcountminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = fileRefcountminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < fileRefcountminDegree { + return gap + } + return fileRefcountGapIterator{sibling, gap.index - fileRefcountminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *fileRefcountnode) rebalanceAfterRemove(gap fileRefcountGapIterator) fileRefcountGapIterator { + for { + if n.nrSegments >= fileRefcountminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return fileRefcountGapIterator{n, 0} + } + if gap.node == n { + return fileRefcountGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= fileRefcountminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + fileRefcountSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return fileRefcountGapIterator{n, n.nrSegments} + } + return fileRefcountGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return fileRefcountGapIterator{p, gap.index} + } + if gap.node == right { + return fileRefcountGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *fileRefcountnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = fileRefcountGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + fileRefcountSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type fileRefcountIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *fileRefcountnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg fileRefcountIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg fileRefcountIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg fileRefcountIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg fileRefcountIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg fileRefcountIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg fileRefcountIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg fileRefcountIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg fileRefcountIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg fileRefcountIterator) Value() int32 { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg fileRefcountIterator) ValuePtr() *int32 { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg fileRefcountIterator) SetValue(val int32) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg fileRefcountIterator) PrevSegment() fileRefcountIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return fileRefcountIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return fileRefcountIterator{} + } + return fileRefcountsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg fileRefcountIterator) NextSegment() fileRefcountIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return fileRefcountIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return fileRefcountIterator{} + } + return fileRefcountsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg fileRefcountIterator) PrevGap() fileRefcountGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return fileRefcountGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg fileRefcountIterator) NextGap() fileRefcountGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return fileRefcountGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg fileRefcountIterator) PrevNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return fileRefcountIterator{}, gap + } + return gap.PrevSegment(), fileRefcountGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg fileRefcountIterator) NextNonEmpty() (fileRefcountIterator, fileRefcountGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return fileRefcountIterator{}, gap + } + return gap.NextSegment(), fileRefcountGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type fileRefcountGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *fileRefcountnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap fileRefcountGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap fileRefcountGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap fileRefcountGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return fileRefcountSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap fileRefcountGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return fileRefcountSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap fileRefcountGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap fileRefcountGapIterator) PrevSegment() fileRefcountIterator { + return fileRefcountsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap fileRefcountGapIterator) NextSegment() fileRefcountIterator { + return fileRefcountsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap fileRefcountGapIterator) PrevGap() fileRefcountGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return fileRefcountGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap fileRefcountGapIterator) NextGap() fileRefcountGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return fileRefcountGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func fileRefcountsegmentBeforePosition(n *fileRefcountnode, i int) fileRefcountIterator { + for i == 0 { + if n.parent == nil { + return fileRefcountIterator{} + } + n, i = n.parent, n.parentIndex + } + return fileRefcountIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func fileRefcountsegmentAfterPosition(n *fileRefcountnode, i int) fileRefcountIterator { + for i == n.nrSegments { + if n.parent == nil { + return fileRefcountIterator{} + } + n, i = n.parent, n.parentIndex + } + return fileRefcountIterator{n, i} +} + +func fileRefcountzeroValueSlice(slice []int32) { + + for i := range slice { + fileRefcountSetFunctions{}.ClearValue(&slice[i]) + } +} + +func fileRefcountzeroNodeSlice(slice []*fileRefcountnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *fileRefcountSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *fileRefcountnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *fileRefcountnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type fileRefcountSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []int32 +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *fileRefcountSet) ExportSortedSlices() *fileRefcountSegmentDataSlices { + var sds fileRefcountSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *fileRefcountSet) ImportSortedSlices(sds *fileRefcountSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *fileRefcountSet) saveRoot() *fileRefcountSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *fileRefcountSet) loadRoot(sds *fileRefcountSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/mm/io_list.go b/pkg/sentry/mm/io_list.go new file mode 100755 index 000000000..99c83c4b9 --- /dev/null +++ b/pkg/sentry/mm/io_list.go @@ -0,0 +1,173 @@ +package mm + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type ioElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (ioElementMapper) linkerFor(elem *ioResult) *ioResult { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type ioList struct { + head *ioResult + tail *ioResult +} + +// Reset resets list l to the empty state. +func (l *ioList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *ioList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *ioList) Front() *ioResult { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *ioList) Back() *ioResult { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *ioList) PushFront(e *ioResult) { + ioElementMapper{}.linkerFor(e).SetNext(l.head) + ioElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + ioElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *ioList) PushBack(e *ioResult) { + ioElementMapper{}.linkerFor(e).SetNext(nil) + ioElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + ioElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *ioList) PushBackList(m *ioList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + ioElementMapper{}.linkerFor(l.tail).SetNext(m.head) + ioElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *ioList) InsertAfter(b, e *ioResult) { + a := ioElementMapper{}.linkerFor(b).Next() + ioElementMapper{}.linkerFor(e).SetNext(a) + ioElementMapper{}.linkerFor(e).SetPrev(b) + ioElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + ioElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *ioList) InsertBefore(a, e *ioResult) { + b := ioElementMapper{}.linkerFor(a).Prev() + ioElementMapper{}.linkerFor(e).SetNext(a) + ioElementMapper{}.linkerFor(e).SetPrev(b) + ioElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + ioElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *ioList) Remove(e *ioResult) { + prev := ioElementMapper{}.linkerFor(e).Prev() + next := ioElementMapper{}.linkerFor(e).Next() + + if prev != nil { + ioElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + ioElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type ioEntry struct { + next *ioResult + prev *ioResult +} + +// Next returns the entry that follows e in the list. +func (e *ioEntry) Next() *ioResult { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *ioEntry) Prev() *ioResult { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *ioEntry) SetNext(elem *ioResult) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *ioEntry) SetPrev(elem *ioResult) { + e.prev = elem +} diff --git a/pkg/sentry/mm/mm_state_autogen.go b/pkg/sentry/mm/mm_state_autogen.go new file mode 100755 index 000000000..29f2f5634 --- /dev/null +++ b/pkg/sentry/mm/mm_state_autogen.go @@ -0,0 +1,386 @@ +// automatically generated by stateify. + +package mm + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *aioManager) beforeSave() {} +func (x *aioManager) save(m state.Map) { + x.beforeSave() + m.Save("contexts", &x.contexts) +} + +func (x *aioManager) afterLoad() {} +func (x *aioManager) load(m state.Map) { + m.Load("contexts", &x.contexts) +} + +func (x *ioResult) beforeSave() {} +func (x *ioResult) save(m state.Map) { + x.beforeSave() + m.Save("data", &x.data) + m.Save("ioEntry", &x.ioEntry) +} + +func (x *ioResult) afterLoad() {} +func (x *ioResult) load(m state.Map) { + m.Load("data", &x.data) + m.Load("ioEntry", &x.ioEntry) +} + +func (x *AIOContext) beforeSave() {} +func (x *AIOContext) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.dead) { m.Failf("dead is %v, expected zero", x.dead) } + m.Save("results", &x.results) + m.Save("maxOutstanding", &x.maxOutstanding) + m.Save("outstanding", &x.outstanding) +} + +func (x *AIOContext) load(m state.Map) { + m.Load("results", &x.results) + m.Load("maxOutstanding", &x.maxOutstanding) + m.Load("outstanding", &x.outstanding) + m.AfterLoad(x.afterLoad) +} + +func (x *aioMappable) beforeSave() {} +func (x *aioMappable) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("mfp", &x.mfp) + m.Save("fr", &x.fr) +} + +func (x *aioMappable) afterLoad() {} +func (x *aioMappable) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("mfp", &x.mfp) + m.Load("fr", &x.fr) +} + +func (x *fileRefcountSet) beforeSave() {} +func (x *fileRefcountSet) save(m state.Map) { + x.beforeSave() + var root *fileRefcountSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *fileRefcountSet) afterLoad() {} +func (x *fileRefcountSet) load(m state.Map) { + m.LoadValue("root", new(*fileRefcountSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*fileRefcountSegmentDataSlices)) }) +} + +func (x *fileRefcountnode) beforeSave() {} +func (x *fileRefcountnode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *fileRefcountnode) afterLoad() {} +func (x *fileRefcountnode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *fileRefcountSegmentDataSlices) beforeSave() {} +func (x *fileRefcountSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *fileRefcountSegmentDataSlices) afterLoad() {} +func (x *fileRefcountSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *ioList) beforeSave() {} +func (x *ioList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *ioList) afterLoad() {} +func (x *ioList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *ioEntry) beforeSave() {} +func (x *ioEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *ioEntry) afterLoad() {} +func (x *ioEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *MemoryManager) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.active) { m.Failf("active is %v, expected zero", x.active) } + if !state.IsZeroValue(x.captureInvalidations) { m.Failf("captureInvalidations is %v, expected zero", x.captureInvalidations) } + m.Save("p", &x.p) + m.Save("mfp", &x.mfp) + m.Save("layout", &x.layout) + m.Save("privateRefs", &x.privateRefs) + m.Save("users", &x.users) + m.Save("vmas", &x.vmas) + m.Save("brk", &x.brk) + m.Save("usageAS", &x.usageAS) + m.Save("lockedAS", &x.lockedAS) + m.Save("dataAS", &x.dataAS) + m.Save("defMLockMode", &x.defMLockMode) + m.Save("pmas", &x.pmas) + m.Save("curRSS", &x.curRSS) + m.Save("maxRSS", &x.maxRSS) + m.Save("argv", &x.argv) + m.Save("envv", &x.envv) + m.Save("auxv", &x.auxv) + m.Save("executable", &x.executable) + m.Save("dumpability", &x.dumpability) + m.Save("aioManager", &x.aioManager) +} + +func (x *MemoryManager) load(m state.Map) { + m.Load("p", &x.p) + m.Load("mfp", &x.mfp) + m.Load("layout", &x.layout) + m.Load("privateRefs", &x.privateRefs) + m.Load("users", &x.users) + m.Load("vmas", &x.vmas) + m.Load("brk", &x.brk) + m.Load("usageAS", &x.usageAS) + m.Load("lockedAS", &x.lockedAS) + m.Load("dataAS", &x.dataAS) + m.Load("defMLockMode", &x.defMLockMode) + m.Load("pmas", &x.pmas) + m.Load("curRSS", &x.curRSS) + m.Load("maxRSS", &x.maxRSS) + m.Load("argv", &x.argv) + m.Load("envv", &x.envv) + m.Load("auxv", &x.auxv) + m.Load("executable", &x.executable) + m.Load("dumpability", &x.dumpability) + m.Load("aioManager", &x.aioManager) + m.AfterLoad(x.afterLoad) +} + +func (x *vma) beforeSave() {} +func (x *vma) save(m state.Map) { + x.beforeSave() + var realPerms int = x.saveRealPerms() + m.SaveValue("realPerms", realPerms) + m.Save("mappable", &x.mappable) + m.Save("off", &x.off) + m.Save("mlockMode", &x.mlockMode) + m.Save("numaPolicy", &x.numaPolicy) + m.Save("numaNodemask", &x.numaNodemask) + m.Save("id", &x.id) + m.Save("hint", &x.hint) +} + +func (x *vma) afterLoad() {} +func (x *vma) load(m state.Map) { + m.Load("mappable", &x.mappable) + m.Load("off", &x.off) + m.Load("mlockMode", &x.mlockMode) + m.Load("numaPolicy", &x.numaPolicy) + m.Load("numaNodemask", &x.numaNodemask) + m.Load("id", &x.id) + m.Load("hint", &x.hint) + m.LoadValue("realPerms", new(int), func(y interface{}) { x.loadRealPerms(y.(int)) }) +} + +func (x *pma) beforeSave() {} +func (x *pma) save(m state.Map) { + x.beforeSave() + m.Save("off", &x.off) + m.Save("translatePerms", &x.translatePerms) + m.Save("effectivePerms", &x.effectivePerms) + m.Save("maxPerms", &x.maxPerms) + m.Save("needCOW", &x.needCOW) + m.Save("private", &x.private) +} + +func (x *pma) afterLoad() {} +func (x *pma) load(m state.Map) { + m.Load("off", &x.off) + m.Load("translatePerms", &x.translatePerms) + m.Load("effectivePerms", &x.effectivePerms) + m.Load("maxPerms", &x.maxPerms) + m.Load("needCOW", &x.needCOW) + m.Load("private", &x.private) +} + +func (x *privateRefs) beforeSave() {} +func (x *privateRefs) save(m state.Map) { + x.beforeSave() + m.Save("refs", &x.refs) +} + +func (x *privateRefs) afterLoad() {} +func (x *privateRefs) load(m state.Map) { + m.Load("refs", &x.refs) +} + +func (x *pmaSet) beforeSave() {} +func (x *pmaSet) save(m state.Map) { + x.beforeSave() + var root *pmaSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *pmaSet) afterLoad() {} +func (x *pmaSet) load(m state.Map) { + m.LoadValue("root", new(*pmaSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*pmaSegmentDataSlices)) }) +} + +func (x *pmanode) beforeSave() {} +func (x *pmanode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *pmanode) afterLoad() {} +func (x *pmanode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *pmaSegmentDataSlices) beforeSave() {} +func (x *pmaSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *pmaSegmentDataSlices) afterLoad() {} +func (x *pmaSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *SpecialMappable) beforeSave() {} +func (x *SpecialMappable) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("mfp", &x.mfp) + m.Save("fr", &x.fr) + m.Save("name", &x.name) +} + +func (x *SpecialMappable) afterLoad() {} +func (x *SpecialMappable) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("mfp", &x.mfp) + m.Load("fr", &x.fr) + m.Load("name", &x.name) +} + +func (x *vmaSet) beforeSave() {} +func (x *vmaSet) save(m state.Map) { + x.beforeSave() + var root *vmaSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *vmaSet) afterLoad() {} +func (x *vmaSet) load(m state.Map) { + m.LoadValue("root", new(*vmaSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*vmaSegmentDataSlices)) }) +} + +func (x *vmanode) beforeSave() {} +func (x *vmanode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *vmanode) afterLoad() {} +func (x *vmanode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *vmaSegmentDataSlices) beforeSave() {} +func (x *vmaSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *vmaSegmentDataSlices) afterLoad() {} +func (x *vmaSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("mm.aioManager", (*aioManager)(nil), state.Fns{Save: (*aioManager).save, Load: (*aioManager).load}) + state.Register("mm.ioResult", (*ioResult)(nil), state.Fns{Save: (*ioResult).save, Load: (*ioResult).load}) + state.Register("mm.AIOContext", (*AIOContext)(nil), state.Fns{Save: (*AIOContext).save, Load: (*AIOContext).load}) + state.Register("mm.aioMappable", (*aioMappable)(nil), state.Fns{Save: (*aioMappable).save, Load: (*aioMappable).load}) + state.Register("mm.fileRefcountSet", (*fileRefcountSet)(nil), state.Fns{Save: (*fileRefcountSet).save, Load: (*fileRefcountSet).load}) + state.Register("mm.fileRefcountnode", (*fileRefcountnode)(nil), state.Fns{Save: (*fileRefcountnode).save, Load: (*fileRefcountnode).load}) + state.Register("mm.fileRefcountSegmentDataSlices", (*fileRefcountSegmentDataSlices)(nil), state.Fns{Save: (*fileRefcountSegmentDataSlices).save, Load: (*fileRefcountSegmentDataSlices).load}) + state.Register("mm.ioList", (*ioList)(nil), state.Fns{Save: (*ioList).save, Load: (*ioList).load}) + state.Register("mm.ioEntry", (*ioEntry)(nil), state.Fns{Save: (*ioEntry).save, Load: (*ioEntry).load}) + state.Register("mm.MemoryManager", (*MemoryManager)(nil), state.Fns{Save: (*MemoryManager).save, Load: (*MemoryManager).load}) + state.Register("mm.vma", (*vma)(nil), state.Fns{Save: (*vma).save, Load: (*vma).load}) + state.Register("mm.pma", (*pma)(nil), state.Fns{Save: (*pma).save, Load: (*pma).load}) + state.Register("mm.privateRefs", (*privateRefs)(nil), state.Fns{Save: (*privateRefs).save, Load: (*privateRefs).load}) + state.Register("mm.pmaSet", (*pmaSet)(nil), state.Fns{Save: (*pmaSet).save, Load: (*pmaSet).load}) + state.Register("mm.pmanode", (*pmanode)(nil), state.Fns{Save: (*pmanode).save, Load: (*pmanode).load}) + state.Register("mm.pmaSegmentDataSlices", (*pmaSegmentDataSlices)(nil), state.Fns{Save: (*pmaSegmentDataSlices).save, Load: (*pmaSegmentDataSlices).load}) + state.Register("mm.SpecialMappable", (*SpecialMappable)(nil), state.Fns{Save: (*SpecialMappable).save, Load: (*SpecialMappable).load}) + state.Register("mm.vmaSet", (*vmaSet)(nil), state.Fns{Save: (*vmaSet).save, Load: (*vmaSet).load}) + state.Register("mm.vmanode", (*vmanode)(nil), state.Fns{Save: (*vmanode).save, Load: (*vmanode).load}) + state.Register("mm.vmaSegmentDataSlices", (*vmaSegmentDataSlices)(nil), state.Fns{Save: (*vmaSegmentDataSlices).save, Load: (*vmaSegmentDataSlices).load}) +} diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go deleted file mode 100644 index 4d2bfaaed..000000000 --- a/pkg/sentry/mm/mm_test.go +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mm - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/context/contexttest" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -func testMemoryManager(ctx context.Context) *MemoryManager { - p := platform.FromContext(ctx) - mfp := pgalloc.MemoryFileProviderFromContext(ctx) - mm := NewMemoryManager(p, mfp) - mm.layout = arch.MmapLayout{ - MinAddr: p.MinUserAddress(), - MaxAddr: p.MaxUserAddress(), - BottomUpBase: p.MinUserAddress(), - TopDownBase: p.MaxUserAddress(), - } - return mm -} - -func (mm *MemoryManager) realUsageAS() uint64 { - return uint64(mm.vmas.Span()) -} - -func TestUsageASUpdates(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 2 * usermem.PageSize, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - realUsage := mm.realUsageAS() - if mm.usageAS != realUsage { - t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) - } - - mm.MUnmap(ctx, addr, usermem.PageSize) - realUsage = mm.realUsageAS() - if mm.usageAS != realUsage { - t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage) - } -} - -func (mm *MemoryManager) realDataAS() uint64 { - var sz uint64 - for seg := mm.vmas.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - vma := seg.Value() - if vma.isPrivateDataLocked() { - sz += uint64(seg.Range().Length()) - } - } - return sz -} - -func TestDataASUpdates(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 3 * usermem.PageSize, - Private: true, - Perms: usermem.Write, - MaxPerms: usermem.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - if mm.dataAS == 0 { - t.Fatalf("dataAS is 0, wanted not 0") - } - realDataAS := mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MUnmap(ctx, addr, usermem.PageSize) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MProtect(addr+usermem.PageSize, usermem.PageSize, usermem.Read, false) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } - - mm.MRemap(ctx, addr+2*usermem.PageSize, usermem.PageSize, 2*usermem.PageSize, MRemapOpts{ - Move: MRemapMayMove, - }) - realDataAS = mm.realDataAS() - if mm.dataAS != realDataAS { - t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS) - } -} - -func TestBrkDataLimitUpdates(t *testing.T) { - limitSet := limits.NewLimitSet() - limitSet.Set(limits.Data, limits.Limit{}, true /* privileged */) // zero RLIMIT_DATA - - ctx := contexttest.WithLimitSet(contexttest.Context(t), limitSet) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - // Try to extend the brk by one page and expect doing so to fail. - oldBrk, _ := mm.Brk(ctx, 0) - if newBrk, _ := mm.Brk(ctx, oldBrk+usermem.PageSize); newBrk != oldBrk { - t.Errorf("brk() increased data segment above RLIMIT_DATA (old brk = %#x, new brk = %#x", oldBrk, newBrk) - } -} - -// TestIOAfterUnmap ensures that IO fails after unmap. -func TestIOAfterUnmap(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: usermem.PageSize, - Private: true, - Perms: usermem.Read, - MaxPerms: usermem.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - - // IO works before munmap. - b := make([]byte, 1) - n, err := mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if err != nil { - t.Errorf("CopyIn got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyIn got %d want 1", n) - } - - err = mm.MUnmap(ctx, addr, usermem.PageSize) - if err != nil { - t.Fatalf("MUnmap got err %v want nil", err) - } - - n, err = mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { - t.Errorf("CopyIn got err %v want EFAULT", err) - } - if n != 0 { - t.Errorf("CopyIn got %d want 0", n) - } -} - -// TestIOAfterMProtect tests IO interaction with mprotect permissions. -func TestIOAfterMProtect(t *testing.T) { - ctx := contexttest.Context(t) - mm := testMemoryManager(ctx) - defer mm.DecUsers(ctx) - - addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: usermem.PageSize, - Private: true, - Perms: usermem.ReadWrite, - MaxPerms: usermem.AnyAccess, - }) - if err != nil { - t.Fatalf("MMap got err %v want nil", err) - } - - // Writing works before mprotect. - b := make([]byte, 1) - n, err := mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if err != nil { - t.Errorf("CopyOut got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyOut got %d want 1", n) - } - - err = mm.MProtect(addr, usermem.PageSize, usermem.Read, false) - if err != nil { - t.Errorf("MProtect got err %v want nil", err) - } - - // Without IgnorePermissions, CopyOut should no longer succeed. - n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { - t.Errorf("CopyOut got err %v want EFAULT", err) - } - if n != 0 { - t.Errorf("CopyOut got %d want 0", n) - } - - // With IgnorePermissions, CopyOut should succeed despite mprotect. - n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{ - IgnorePermissions: true, - }) - if err != nil { - t.Errorf("CopyOut got err %v want nil", err) - } - if n != 1 { - t.Errorf("CopyOut got %d want 1", n) - } -} diff --git a/pkg/sentry/mm/pma_set.go b/pkg/sentry/mm/pma_set.go new file mode 100755 index 000000000..a048a299c --- /dev/null +++ b/pkg/sentry/mm/pma_set.go @@ -0,0 +1,1274 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + pmaminDegree = 8 + + pmamaxDegree = 2 * pmaminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type pmaSet struct { + root pmanode `state:".(*pmaSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *pmaSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *pmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *pmaSet) Span() __generics_imported0.Addr { + var sz __generics_imported0.Addr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *pmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz __generics_imported0.Addr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *pmaSet) FirstSegment() pmaIterator { + if s.root.nrSegments == 0 { + return pmaIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *pmaSet) LastSegment() pmaIterator { + if s.root.nrSegments == 0 { + return pmaIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *pmaSet) FirstGap() pmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return pmaGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *pmaSet) LastGap() pmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return pmaGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *pmaSet) Find(key __generics_imported0.Addr) (pmaIterator, pmaGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return pmaIterator{n, i}, pmaGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return pmaIterator{}, pmaGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *pmaSet) FindSegment(key __generics_imported0.Addr) pmaIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *pmaSet) LowerBoundSegment(min __generics_imported0.Addr) pmaIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *pmaSet) UpperBoundSegment(max __generics_imported0.Addr) pmaIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *pmaSet) FindGap(key __generics_imported0.Addr) pmaGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *pmaSet) LowerBoundGap(min __generics_imported0.Addr) pmaGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *pmaSet) UpperBoundGap(max __generics_imported0.Addr) pmaGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *pmaSet) Add(r __generics_imported0.AddrRange, val pma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *pmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val pma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *pmaSet) Insert(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (pmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *pmaSet) InsertWithoutMerging(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *pmaSet) InsertWithoutMergingUnchecked(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return pmaIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *pmaSet) Remove(seg pmaIterator) pmaGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + pmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(pmaGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *pmaSet) RemoveAll() { + s.root = pmanode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *pmaSet) RemoveRange(r __generics_imported0.AddrRange) pmaGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *pmaSet) Merge(first, second pmaIterator) pmaIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *pmaSet) MergeUnchecked(first, second pmaIterator) pmaIterator { + if first.End() == second.Start() { + if mval, ok := (pmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return pmaIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *pmaSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *pmaSet) MergeRange(r __generics_imported0.AddrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *pmaSet) MergeAdjacent(r __generics_imported0.AddrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *pmaSet) Split(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *pmaSet) SplitUnchecked(seg pmaIterator, split __generics_imported0.Addr) (pmaIterator, pmaIterator) { + val1, val2 := (pmaSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *pmaSet) SplitAt(split __generics_imported0.Addr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *pmaSet) Isolate(seg pmaIterator, r __generics_imported0.AddrRange) pmaIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *pmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg pmaIterator)) pmaGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return pmaGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return pmaGapIterator{} + } + } +} + +// +stateify savable +type pmanode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *pmanode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [pmamaxDegree - 1]__generics_imported0.AddrRange + values [pmamaxDegree - 1]pma + children [pmamaxDegree]*pmanode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *pmanode) firstSegment() pmaIterator { + for n.hasChildren { + n = n.children[0] + } + return pmaIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *pmanode) lastSegment() pmaIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return pmaIterator{n, n.nrSegments - 1} +} + +func (n *pmanode) prevSibling() *pmanode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *pmanode) nextSibling() *pmanode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *pmanode) rebalanceBeforeInsert(gap pmaGapIterator) pmaGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < pmamaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:pmaminDegree-1], n.keys[:pmaminDegree-1]) + copy(left.values[:pmaminDegree-1], n.values[:pmaminDegree-1]) + copy(right.keys[:pmaminDegree-1], n.keys[pmaminDegree:]) + copy(right.values[:pmaminDegree-1], n.values[pmaminDegree:]) + n.keys[0], n.values[0] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1] + pmazeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:pmaminDegree], n.children[:pmaminDegree]) + copy(right.children[:pmaminDegree], n.children[pmaminDegree:]) + pmazeroNodeSlice(n.children[2:]) + for i := 0; i < pmaminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < pmaminDegree { + return pmaGapIterator{left, gap.index} + } + return pmaGapIterator{right, gap.index - pmaminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[pmaminDegree-1], n.values[pmaminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &pmanode{ + nrSegments: pmaminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:pmaminDegree-1], n.keys[pmaminDegree:]) + copy(sibling.values[:pmaminDegree-1], n.values[pmaminDegree:]) + pmazeroValueSlice(n.values[pmaminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:pmaminDegree], n.children[pmaminDegree:]) + pmazeroNodeSlice(n.children[pmaminDegree:]) + for i := 0; i < pmaminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = pmaminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < pmaminDegree { + return gap + } + return pmaGapIterator{sibling, gap.index - pmaminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator { + for { + if n.nrSegments >= pmaminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return pmaGapIterator{n, 0} + } + if gap.node == n { + return pmaGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= pmaminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + pmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return pmaGapIterator{n, n.nrSegments} + } + return pmaGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return pmaGapIterator{p, gap.index} + } + if gap.node == right { + return pmaGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *pmanode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = pmaGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + pmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type pmaIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *pmanode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg pmaIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg pmaIterator) Range() __generics_imported0.AddrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg pmaIterator) Start() __generics_imported0.Addr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg pmaIterator) End() __generics_imported0.Addr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg pmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetRange(r __generics_imported0.AddrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg pmaIterator) SetStartUnchecked(start __generics_imported0.Addr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetStart(start __generics_imported0.Addr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg pmaIterator) SetEndUnchecked(end __generics_imported0.Addr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg pmaIterator) SetEnd(end __generics_imported0.Addr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg pmaIterator) Value() pma { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg pmaIterator) ValuePtr() *pma { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg pmaIterator) SetValue(val pma) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg pmaIterator) PrevSegment() pmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return pmaIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return pmaIterator{} + } + return pmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg pmaIterator) NextSegment() pmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return pmaIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return pmaIterator{} + } + return pmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg pmaIterator) PrevGap() pmaGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return pmaGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg pmaIterator) NextGap() pmaGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return pmaGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg pmaIterator) PrevNonEmpty() (pmaIterator, pmaGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return pmaIterator{}, gap + } + return gap.PrevSegment(), pmaGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg pmaIterator) NextNonEmpty() (pmaIterator, pmaGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return pmaIterator{}, gap + } + return gap.NextSegment(), pmaGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type pmaGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *pmanode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap pmaGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap pmaGapIterator) Range() __generics_imported0.AddrRange { + return __generics_imported0.AddrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap pmaGapIterator) Start() __generics_imported0.Addr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return pmaSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap pmaGapIterator) End() __generics_imported0.Addr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return pmaSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap pmaGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap pmaGapIterator) PrevSegment() pmaIterator { + return pmasegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap pmaGapIterator) NextSegment() pmaIterator { + return pmasegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap pmaGapIterator) PrevGap() pmaGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return pmaGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap pmaGapIterator) NextGap() pmaGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return pmaGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func pmasegmentBeforePosition(n *pmanode, i int) pmaIterator { + for i == 0 { + if n.parent == nil { + return pmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return pmaIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func pmasegmentAfterPosition(n *pmanode, i int) pmaIterator { + for i == n.nrSegments { + if n.parent == nil { + return pmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return pmaIterator{n, i} +} + +func pmazeroValueSlice(slice []pma) { + + for i := range slice { + pmaSetFunctions{}.ClearValue(&slice[i]) + } +} + +func pmazeroNodeSlice(slice []*pmanode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *pmaSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *pmanode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *pmanode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type pmaSegmentDataSlices struct { + Start []__generics_imported0.Addr + End []__generics_imported0.Addr + Values []pma +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *pmaSet) ExportSortedSlices() *pmaSegmentDataSlices { + var sds pmaSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *pmaSet) ImportSortedSlices(sds *pmaSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *pmaSet) saveRoot() *pmaSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *pmaSet) loadRoot(sds *pmaSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/mm/vma_set.go b/pkg/sentry/mm/vma_set.go new file mode 100755 index 000000000..0211e2fce --- /dev/null +++ b/pkg/sentry/mm/vma_set.go @@ -0,0 +1,1274 @@ +package mm + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + vmaminDegree = 8 + + vmamaxDegree = 2 * vmaminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type vmaSet struct { + root vmanode `state:".(*vmaSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *vmaSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *vmaSet) IsEmptyRange(r __generics_imported0.AddrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *vmaSet) Span() __generics_imported0.Addr { + var sz __generics_imported0.Addr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *vmaSet) SpanRange(r __generics_imported0.AddrRange) __generics_imported0.Addr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz __generics_imported0.Addr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *vmaSet) FirstSegment() vmaIterator { + if s.root.nrSegments == 0 { + return vmaIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *vmaSet) LastSegment() vmaIterator { + if s.root.nrSegments == 0 { + return vmaIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *vmaSet) FirstGap() vmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return vmaGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *vmaSet) LastGap() vmaGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return vmaGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *vmaSet) Find(key __generics_imported0.Addr) (vmaIterator, vmaGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return vmaIterator{n, i}, vmaGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return vmaIterator{}, vmaGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *vmaSet) FindSegment(key __generics_imported0.Addr) vmaIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *vmaSet) LowerBoundSegment(min __generics_imported0.Addr) vmaIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *vmaSet) UpperBoundSegment(max __generics_imported0.Addr) vmaIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *vmaSet) FindGap(key __generics_imported0.Addr) vmaGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *vmaSet) LowerBoundGap(min __generics_imported0.Addr) vmaGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *vmaSet) UpperBoundGap(max __generics_imported0.Addr) vmaGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *vmaSet) Add(r __generics_imported0.AddrRange, val vma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *vmaSet) AddWithoutMerging(r __generics_imported0.AddrRange, val vma) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *vmaSet) Insert(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (vmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (vmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *vmaSet) InsertWithoutMerging(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *vmaSet) InsertWithoutMergingUnchecked(gap vmaGapIterator, r __generics_imported0.AddrRange, val vma) vmaIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return vmaIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *vmaSet) Remove(seg vmaIterator) vmaGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + vmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(vmaGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *vmaSet) RemoveAll() { + s.root = vmanode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *vmaSet) RemoveRange(r __generics_imported0.AddrRange) vmaGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *vmaSet) Merge(first, second vmaIterator) vmaIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *vmaSet) MergeUnchecked(first, second vmaIterator) vmaIterator { + if first.End() == second.Start() { + if mval, ok := (vmaSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return vmaIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *vmaSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *vmaSet) MergeRange(r __generics_imported0.AddrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *vmaSet) MergeAdjacent(r __generics_imported0.AddrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *vmaSet) Split(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *vmaSet) SplitUnchecked(seg vmaIterator, split __generics_imported0.Addr) (vmaIterator, vmaIterator) { + val1, val2 := (vmaSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.AddrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *vmaSet) SplitAt(split __generics_imported0.Addr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *vmaSet) Isolate(seg vmaIterator, r __generics_imported0.AddrRange) vmaIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *vmaSet) ApplyContiguous(r __generics_imported0.AddrRange, fn func(seg vmaIterator)) vmaGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return vmaGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return vmaGapIterator{} + } + } +} + +// +stateify savable +type vmanode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *vmanode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [vmamaxDegree - 1]__generics_imported0.AddrRange + values [vmamaxDegree - 1]vma + children [vmamaxDegree]*vmanode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *vmanode) firstSegment() vmaIterator { + for n.hasChildren { + n = n.children[0] + } + return vmaIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *vmanode) lastSegment() vmaIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return vmaIterator{n, n.nrSegments - 1} +} + +func (n *vmanode) prevSibling() *vmanode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *vmanode) nextSibling() *vmanode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *vmanode) rebalanceBeforeInsert(gap vmaGapIterator) vmaGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < vmamaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:vmaminDegree-1], n.keys[:vmaminDegree-1]) + copy(left.values[:vmaminDegree-1], n.values[:vmaminDegree-1]) + copy(right.keys[:vmaminDegree-1], n.keys[vmaminDegree:]) + copy(right.values[:vmaminDegree-1], n.values[vmaminDegree:]) + n.keys[0], n.values[0] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1] + vmazeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:vmaminDegree], n.children[:vmaminDegree]) + copy(right.children[:vmaminDegree], n.children[vmaminDegree:]) + vmazeroNodeSlice(n.children[2:]) + for i := 0; i < vmaminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < vmaminDegree { + return vmaGapIterator{left, gap.index} + } + return vmaGapIterator{right, gap.index - vmaminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[vmaminDegree-1], n.values[vmaminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &vmanode{ + nrSegments: vmaminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:vmaminDegree-1], n.keys[vmaminDegree:]) + copy(sibling.values[:vmaminDegree-1], n.values[vmaminDegree:]) + vmazeroValueSlice(n.values[vmaminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:vmaminDegree], n.children[vmaminDegree:]) + vmazeroNodeSlice(n.children[vmaminDegree:]) + for i := 0; i < vmaminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = vmaminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < vmaminDegree { + return gap + } + return vmaGapIterator{sibling, gap.index - vmaminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *vmanode) rebalanceAfterRemove(gap vmaGapIterator) vmaGapIterator { + for { + if n.nrSegments >= vmaminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return vmaGapIterator{n, 0} + } + if gap.node == n { + return vmaGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= vmaminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + vmaSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return vmaGapIterator{n, n.nrSegments} + } + return vmaGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return vmaGapIterator{p, gap.index} + } + if gap.node == right { + return vmaGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *vmanode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = vmaGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + vmaSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type vmaIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *vmanode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg vmaIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg vmaIterator) Range() __generics_imported0.AddrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg vmaIterator) Start() __generics_imported0.Addr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg vmaIterator) End() __generics_imported0.Addr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg vmaIterator) SetRangeUnchecked(r __generics_imported0.AddrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetRange(r __generics_imported0.AddrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg vmaIterator) SetStartUnchecked(start __generics_imported0.Addr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetStart(start __generics_imported0.Addr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg vmaIterator) SetEndUnchecked(end __generics_imported0.Addr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg vmaIterator) SetEnd(end __generics_imported0.Addr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg vmaIterator) Value() vma { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg vmaIterator) ValuePtr() *vma { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg vmaIterator) SetValue(val vma) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg vmaIterator) PrevSegment() vmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return vmaIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return vmaIterator{} + } + return vmasegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg vmaIterator) NextSegment() vmaIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return vmaIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return vmaIterator{} + } + return vmasegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg vmaIterator) PrevGap() vmaGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return vmaGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg vmaIterator) NextGap() vmaGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return vmaGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg vmaIterator) PrevNonEmpty() (vmaIterator, vmaGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return vmaIterator{}, gap + } + return gap.PrevSegment(), vmaGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg vmaIterator) NextNonEmpty() (vmaIterator, vmaGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return vmaIterator{}, gap + } + return gap.NextSegment(), vmaGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type vmaGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *vmanode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap vmaGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap vmaGapIterator) Range() __generics_imported0.AddrRange { + return __generics_imported0.AddrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap vmaGapIterator) Start() __generics_imported0.Addr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return vmaSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap vmaGapIterator) End() __generics_imported0.Addr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return vmaSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap vmaGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap vmaGapIterator) PrevSegment() vmaIterator { + return vmasegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap vmaGapIterator) NextSegment() vmaIterator { + return vmasegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap vmaGapIterator) PrevGap() vmaGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return vmaGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap vmaGapIterator) NextGap() vmaGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return vmaGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func vmasegmentBeforePosition(n *vmanode, i int) vmaIterator { + for i == 0 { + if n.parent == nil { + return vmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return vmaIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func vmasegmentAfterPosition(n *vmanode, i int) vmaIterator { + for i == n.nrSegments { + if n.parent == nil { + return vmaIterator{} + } + n, i = n.parent, n.parentIndex + } + return vmaIterator{n, i} +} + +func vmazeroValueSlice(slice []vma) { + + for i := range slice { + vmaSetFunctions{}.ClearValue(&slice[i]) + } +} + +func vmazeroNodeSlice(slice []*vmanode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *vmaSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *vmanode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *vmanode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type vmaSegmentDataSlices struct { + Start []__generics_imported0.Addr + End []__generics_imported0.Addr + Values []vma +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *vmaSet) ExportSortedSlices() *vmaSegmentDataSlices { + var sds vmaSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *vmaSet) ImportSortedSlices(sds *vmaSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.AddrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *vmaSet) saveRoot() *vmaSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *vmaSet) loadRoot(sds *vmaSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD deleted file mode 100644 index 858f895f2..000000000 --- a/pkg/sentry/pgalloc/BUILD +++ /dev/null @@ -1,85 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "evictable_range", - out = "evictable_range.go", - package = "pgalloc", - prefix = "Evictable", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_template_instance( - name = "evictable_range_set", - out = "evictable_range_set.go", - package = "pgalloc", - prefix = "evictableRange", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "EvictableRange", - "Value": "evictableRangeSetValue", - "Functions": "evictableRangeSetFunctions", - }, -) - -go_template_instance( - name = "usage_set", - out = "usage_set.go", - consts = { - "minDegree": "10", - }, - imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", - }, - package = "pgalloc", - prefix = "usage", - template = "//pkg/segment:generic_set", - types = { - "Key": "uint64", - "Range": "platform.FileRange", - "Value": "usageInfo", - "Functions": "usageSetFunctions", - }, -) - -go_library( - name = "pgalloc", - srcs = [ - "context.go", - "evictable_range.go", - "evictable_range_set.go", - "pgalloc.go", - "pgalloc_unsafe.go", - "save_restore.go", - "usage_set.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/pgalloc", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/log", - "//pkg/memutil", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/hostmm", - "//pkg/sentry/platform", - "//pkg/sentry/safemem", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/state", - "//pkg/syserror", - ], -) - -go_test( - name = "pgalloc_test", - size = "small", - srcs = ["pgalloc_test.go"], - embed = [":pgalloc"], - deps = ["//pkg/sentry/usermem"], -) diff --git a/pkg/sentry/pgalloc/evictable_range.go b/pkg/sentry/pgalloc/evictable_range.go new file mode 100755 index 000000000..10ce2ff44 --- /dev/null +++ b/pkg/sentry/pgalloc/evictable_range.go @@ -0,0 +1,62 @@ +package pgalloc + +// A Range represents a contiguous range of T. +// +// +stateify savable +type EvictableRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r EvictableRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r EvictableRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r EvictableRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r EvictableRange) Overlaps(r2 EvictableRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r EvictableRange) IsSupersetOf(r2 EvictableRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r EvictableRange) Intersect(r2 EvictableRange) EvictableRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r EvictableRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/pgalloc/evictable_range_set.go b/pkg/sentry/pgalloc/evictable_range_set.go new file mode 100755 index 000000000..a4dcb1663 --- /dev/null +++ b/pkg/sentry/pgalloc/evictable_range_set.go @@ -0,0 +1,1270 @@ +package pgalloc + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + evictableRangeminDegree = 3 + + evictableRangemaxDegree = 2 * evictableRangeminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type evictableRangeSet struct { + root evictableRangenode `state:".(*evictableRangeSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *evictableRangeSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *evictableRangeSet) IsEmptyRange(r EvictableRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *evictableRangeSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *evictableRangeSet) SpanRange(r EvictableRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *evictableRangeSet) FirstSegment() evictableRangeIterator { + if s.root.nrSegments == 0 { + return evictableRangeIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *evictableRangeSet) LastSegment() evictableRangeIterator { + if s.root.nrSegments == 0 { + return evictableRangeIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *evictableRangeSet) FirstGap() evictableRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return evictableRangeGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *evictableRangeSet) LastGap() evictableRangeGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return evictableRangeGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *evictableRangeSet) Find(key uint64) (evictableRangeIterator, evictableRangeGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return evictableRangeIterator{n, i}, evictableRangeGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return evictableRangeIterator{}, evictableRangeGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *evictableRangeSet) FindSegment(key uint64) evictableRangeIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *evictableRangeSet) LowerBoundSegment(min uint64) evictableRangeIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *evictableRangeSet) UpperBoundSegment(max uint64) evictableRangeIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *evictableRangeSet) FindGap(key uint64) evictableRangeGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *evictableRangeSet) LowerBoundGap(min uint64) evictableRangeGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *evictableRangeSet) UpperBoundGap(max uint64) evictableRangeGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *evictableRangeSet) Add(r EvictableRange, val evictableRangeSetValue) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *evictableRangeSet) AddWithoutMerging(r EvictableRange, val evictableRangeSetValue) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *evictableRangeSet) Insert(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (evictableRangeSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (evictableRangeSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *evictableRangeSet) InsertWithoutMerging(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *evictableRangeSet) InsertWithoutMergingUnchecked(gap evictableRangeGapIterator, r EvictableRange, val evictableRangeSetValue) evictableRangeIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return evictableRangeIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *evictableRangeSet) Remove(seg evictableRangeIterator) evictableRangeGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + evictableRangeSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(evictableRangeGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *evictableRangeSet) RemoveAll() { + s.root = evictableRangenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *evictableRangeSet) RemoveRange(r EvictableRange) evictableRangeGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *evictableRangeSet) Merge(first, second evictableRangeIterator) evictableRangeIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *evictableRangeSet) MergeUnchecked(first, second evictableRangeIterator) evictableRangeIterator { + if first.End() == second.Start() { + if mval, ok := (evictableRangeSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return evictableRangeIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *evictableRangeSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *evictableRangeSet) MergeRange(r EvictableRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *evictableRangeSet) MergeAdjacent(r EvictableRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *evictableRangeSet) Split(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *evictableRangeSet) SplitUnchecked(seg evictableRangeIterator, split uint64) (evictableRangeIterator, evictableRangeIterator) { + val1, val2 := (evictableRangeSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), EvictableRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *evictableRangeSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *evictableRangeSet) Isolate(seg evictableRangeIterator, r EvictableRange) evictableRangeIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *evictableRangeSet) ApplyContiguous(r EvictableRange, fn func(seg evictableRangeIterator)) evictableRangeGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return evictableRangeGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return evictableRangeGapIterator{} + } + } +} + +// +stateify savable +type evictableRangenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *evictableRangenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [evictableRangemaxDegree - 1]EvictableRange + values [evictableRangemaxDegree - 1]evictableRangeSetValue + children [evictableRangemaxDegree]*evictableRangenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *evictableRangenode) firstSegment() evictableRangeIterator { + for n.hasChildren { + n = n.children[0] + } + return evictableRangeIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *evictableRangenode) lastSegment() evictableRangeIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return evictableRangeIterator{n, n.nrSegments - 1} +} + +func (n *evictableRangenode) prevSibling() *evictableRangenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *evictableRangenode) nextSibling() *evictableRangenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *evictableRangenode) rebalanceBeforeInsert(gap evictableRangeGapIterator) evictableRangeGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < evictableRangemaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:evictableRangeminDegree-1], n.keys[:evictableRangeminDegree-1]) + copy(left.values[:evictableRangeminDegree-1], n.values[:evictableRangeminDegree-1]) + copy(right.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:]) + copy(right.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:]) + n.keys[0], n.values[0] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1] + evictableRangezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:evictableRangeminDegree], n.children[:evictableRangeminDegree]) + copy(right.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:]) + evictableRangezeroNodeSlice(n.children[2:]) + for i := 0; i < evictableRangeminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < evictableRangeminDegree { + return evictableRangeGapIterator{left, gap.index} + } + return evictableRangeGapIterator{right, gap.index - evictableRangeminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[evictableRangeminDegree-1], n.values[evictableRangeminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &evictableRangenode{ + nrSegments: evictableRangeminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:evictableRangeminDegree-1], n.keys[evictableRangeminDegree:]) + copy(sibling.values[:evictableRangeminDegree-1], n.values[evictableRangeminDegree:]) + evictableRangezeroValueSlice(n.values[evictableRangeminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:evictableRangeminDegree], n.children[evictableRangeminDegree:]) + evictableRangezeroNodeSlice(n.children[evictableRangeminDegree:]) + for i := 0; i < evictableRangeminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = evictableRangeminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < evictableRangeminDegree { + return gap + } + return evictableRangeGapIterator{sibling, gap.index - evictableRangeminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *evictableRangenode) rebalanceAfterRemove(gap evictableRangeGapIterator) evictableRangeGapIterator { + for { + if n.nrSegments >= evictableRangeminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return evictableRangeGapIterator{n, 0} + } + if gap.node == n { + return evictableRangeGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= evictableRangeminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + evictableRangeSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return evictableRangeGapIterator{n, n.nrSegments} + } + return evictableRangeGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return evictableRangeGapIterator{p, gap.index} + } + if gap.node == right { + return evictableRangeGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *evictableRangenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = evictableRangeGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + evictableRangeSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type evictableRangeIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *evictableRangenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg evictableRangeIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg evictableRangeIterator) Range() EvictableRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg evictableRangeIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg evictableRangeIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg evictableRangeIterator) SetRangeUnchecked(r EvictableRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetRange(r EvictableRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg evictableRangeIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg evictableRangeIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg evictableRangeIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg evictableRangeIterator) Value() evictableRangeSetValue { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg evictableRangeIterator) ValuePtr() *evictableRangeSetValue { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg evictableRangeIterator) SetValue(val evictableRangeSetValue) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg evictableRangeIterator) PrevSegment() evictableRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return evictableRangeIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return evictableRangeIterator{} + } + return evictableRangesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg evictableRangeIterator) NextSegment() evictableRangeIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return evictableRangeIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return evictableRangeIterator{} + } + return evictableRangesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg evictableRangeIterator) PrevGap() evictableRangeGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return evictableRangeGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg evictableRangeIterator) NextGap() evictableRangeGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return evictableRangeGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg evictableRangeIterator) PrevNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return evictableRangeIterator{}, gap + } + return gap.PrevSegment(), evictableRangeGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg evictableRangeIterator) NextNonEmpty() (evictableRangeIterator, evictableRangeGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return evictableRangeIterator{}, gap + } + return gap.NextSegment(), evictableRangeGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type evictableRangeGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *evictableRangenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap evictableRangeGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap evictableRangeGapIterator) Range() EvictableRange { + return EvictableRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap evictableRangeGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return evictableRangeSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap evictableRangeGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return evictableRangeSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap evictableRangeGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap evictableRangeGapIterator) PrevSegment() evictableRangeIterator { + return evictableRangesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap evictableRangeGapIterator) NextSegment() evictableRangeIterator { + return evictableRangesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap evictableRangeGapIterator) PrevGap() evictableRangeGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return evictableRangeGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap evictableRangeGapIterator) NextGap() evictableRangeGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return evictableRangeGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func evictableRangesegmentBeforePosition(n *evictableRangenode, i int) evictableRangeIterator { + for i == 0 { + if n.parent == nil { + return evictableRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return evictableRangeIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func evictableRangesegmentAfterPosition(n *evictableRangenode, i int) evictableRangeIterator { + for i == n.nrSegments { + if n.parent == nil { + return evictableRangeIterator{} + } + n, i = n.parent, n.parentIndex + } + return evictableRangeIterator{n, i} +} + +func evictableRangezeroValueSlice(slice []evictableRangeSetValue) { + + for i := range slice { + evictableRangeSetFunctions{}.ClearValue(&slice[i]) + } +} + +func evictableRangezeroNodeSlice(slice []*evictableRangenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *evictableRangeSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *evictableRangenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *evictableRangenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type evictableRangeSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []evictableRangeSetValue +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *evictableRangeSet) ExportSortedSlices() *evictableRangeSegmentDataSlices { + var sds evictableRangeSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *evictableRangeSet) ImportSortedSlices(sds *evictableRangeSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := EvictableRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *evictableRangeSet) saveRoot() *evictableRangeSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *evictableRangeSet) loadRoot(sds *evictableRangeSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/pgalloc/pgalloc_state_autogen.go b/pkg/sentry/pgalloc/pgalloc_state_autogen.go new file mode 100755 index 000000000..1cfacc241 --- /dev/null +++ b/pkg/sentry/pgalloc/pgalloc_state_autogen.go @@ -0,0 +1,146 @@ +// automatically generated by stateify. + +package pgalloc + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *EvictableRange) beforeSave() {} +func (x *EvictableRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *EvictableRange) afterLoad() {} +func (x *EvictableRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *evictableRangeSet) beforeSave() {} +func (x *evictableRangeSet) save(m state.Map) { + x.beforeSave() + var root *evictableRangeSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *evictableRangeSet) afterLoad() {} +func (x *evictableRangeSet) load(m state.Map) { + m.LoadValue("root", new(*evictableRangeSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*evictableRangeSegmentDataSlices)) }) +} + +func (x *evictableRangenode) beforeSave() {} +func (x *evictableRangenode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *evictableRangenode) afterLoad() {} +func (x *evictableRangenode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *evictableRangeSegmentDataSlices) beforeSave() {} +func (x *evictableRangeSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *evictableRangeSegmentDataSlices) afterLoad() {} +func (x *evictableRangeSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func (x *usageInfo) beforeSave() {} +func (x *usageInfo) save(m state.Map) { + x.beforeSave() + m.Save("kind", &x.kind) + m.Save("knownCommitted", &x.knownCommitted) + m.Save("refs", &x.refs) +} + +func (x *usageInfo) afterLoad() {} +func (x *usageInfo) load(m state.Map) { + m.Load("kind", &x.kind) + m.Load("knownCommitted", &x.knownCommitted) + m.Load("refs", &x.refs) +} + +func (x *usageSet) beforeSave() {} +func (x *usageSet) save(m state.Map) { + x.beforeSave() + var root *usageSegmentDataSlices = x.saveRoot() + m.SaveValue("root", root) +} + +func (x *usageSet) afterLoad() {} +func (x *usageSet) load(m state.Map) { + m.LoadValue("root", new(*usageSegmentDataSlices), func(y interface{}) { x.loadRoot(y.(*usageSegmentDataSlices)) }) +} + +func (x *usagenode) beforeSave() {} +func (x *usagenode) save(m state.Map) { + x.beforeSave() + m.Save("nrSegments", &x.nrSegments) + m.Save("parent", &x.parent) + m.Save("parentIndex", &x.parentIndex) + m.Save("hasChildren", &x.hasChildren) + m.Save("keys", &x.keys) + m.Save("values", &x.values) + m.Save("children", &x.children) +} + +func (x *usagenode) afterLoad() {} +func (x *usagenode) load(m state.Map) { + m.Load("nrSegments", &x.nrSegments) + m.Load("parent", &x.parent) + m.Load("parentIndex", &x.parentIndex) + m.Load("hasChildren", &x.hasChildren) + m.Load("keys", &x.keys) + m.Load("values", &x.values) + m.Load("children", &x.children) +} + +func (x *usageSegmentDataSlices) beforeSave() {} +func (x *usageSegmentDataSlices) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) + m.Save("Values", &x.Values) +} + +func (x *usageSegmentDataSlices) afterLoad() {} +func (x *usageSegmentDataSlices) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) + m.Load("Values", &x.Values) +} + +func init() { + state.Register("pgalloc.EvictableRange", (*EvictableRange)(nil), state.Fns{Save: (*EvictableRange).save, Load: (*EvictableRange).load}) + state.Register("pgalloc.evictableRangeSet", (*evictableRangeSet)(nil), state.Fns{Save: (*evictableRangeSet).save, Load: (*evictableRangeSet).load}) + state.Register("pgalloc.evictableRangenode", (*evictableRangenode)(nil), state.Fns{Save: (*evictableRangenode).save, Load: (*evictableRangenode).load}) + state.Register("pgalloc.evictableRangeSegmentDataSlices", (*evictableRangeSegmentDataSlices)(nil), state.Fns{Save: (*evictableRangeSegmentDataSlices).save, Load: (*evictableRangeSegmentDataSlices).load}) + state.Register("pgalloc.usageInfo", (*usageInfo)(nil), state.Fns{Save: (*usageInfo).save, Load: (*usageInfo).load}) + state.Register("pgalloc.usageSet", (*usageSet)(nil), state.Fns{Save: (*usageSet).save, Load: (*usageSet).load}) + state.Register("pgalloc.usagenode", (*usagenode)(nil), state.Fns{Save: (*usagenode).save, Load: (*usagenode).load}) + state.Register("pgalloc.usageSegmentDataSlices", (*usageSegmentDataSlices)(nil), state.Fns{Save: (*usageSegmentDataSlices).save, Load: (*usageSegmentDataSlices).load}) +} diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go deleted file mode 100644 index 428e6a859..000000000 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgalloc - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -const ( - page = usermem.PageSize - hugepage = usermem.HugePageSize -) - -func TestFindUnallocatedRange(t *testing.T) { - for _, test := range []struct { - desc string - usage *usageSegmentDataSlices - start uint64 - length uint64 - alignment uint64 - unallocated uint64 - minUnallocated uint64 - }{ - { - desc: "Initial allocation succeeds", - usage: &usageSegmentDataSlices{}, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, - }, - { - desc: "Allocation begins at start of file", - usage: &usageSegmentDataSlices{ - Start: []uint64{page}, - End: []uint64{2 * page}, - Values: []usageInfo{{refs: 1}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, - }, - { - desc: "In-use frames are not allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, page}, - End: []uint64{page, 2 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, - }, - { - desc: "Reclaimable frames are not allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, page, 2 * page}, - End: []uint64{page, 2 * page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: 3 * page, - minUnallocated: 3 * page, - }, - { - desc: "Gaps between in-use frames are allocatable", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 2 * page}, - End: []uint64{page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - start: 0, - length: page, - alignment: page, - unallocated: page, - minUnallocated: page, - }, - { - desc: "Inadequately-sized gaps are rejected", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 2 * page}, - End: []uint64{page, 3 * page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - start: 0, - length: 2 * page, - alignment: page, - unallocated: 3 * page, - minUnallocated: page, - }, - { - desc: "Hugepage alignment is honored", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, hugepage + page}, - // Hugepage-sized gap here that shouldn't be allocated from - // since it's incorrectly aligned. - End: []uint64{page, hugepage + 2*page}, - Values: []usageInfo{{refs: 1}, {refs: 1}}, - }, - start: 0, - length: hugepage, - alignment: hugepage, - unallocated: 2 * hugepage, - minUnallocated: page, - }, - { - desc: "Pages before start ignored", - usage: &usageSegmentDataSlices{ - Start: []uint64{page, 3 * page}, - End: []uint64{2 * page, 4 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, - }, - { - desc: "start may be in the middle of segment", - usage: &usageSegmentDataSlices{ - Start: []uint64{0, 3 * page}, - End: []uint64{2 * page, 4 * page}, - Values: []usageInfo{{refs: 1}, {refs: 2}}, - }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, - }, - } { - t.Run(test.desc, func(t *testing.T) { - var usage usageSet - if err := usage.ImportSortedSlices(test.usage); err != nil { - t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err) - } - unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment) - if unallocated != test.unallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated) - } - if minUnallocated != test.minUnallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated) - } - }) - } -} diff --git a/pkg/sentry/pgalloc/usage_set.go b/pkg/sentry/pgalloc/usage_set.go new file mode 100755 index 000000000..612807caf --- /dev/null +++ b/pkg/sentry/pgalloc/usage_set.go @@ -0,0 +1,1274 @@ +package pgalloc + +import ( + __generics_imported0 "gvisor.dev/gvisor/pkg/sentry/platform" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + usageminDegree = 10 + + usagemaxDegree = 2 * usageminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type usageSet struct { + root usagenode `state:".(*usageSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *usageSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *usageSet) IsEmptyRange(r __generics_imported0.FileRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *usageSet) Span() uint64 { + var sz uint64 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *usageSet) SpanRange(r __generics_imported0.FileRange) uint64 { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uint64 + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *usageSet) FirstSegment() usageIterator { + if s.root.nrSegments == 0 { + return usageIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *usageSet) LastSegment() usageIterator { + if s.root.nrSegments == 0 { + return usageIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *usageSet) FirstGap() usageGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return usageGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *usageSet) LastGap() usageGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return usageGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *usageSet) Find(key uint64) (usageIterator, usageGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return usageIterator{n, i}, usageGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return usageIterator{}, usageGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *usageSet) FindSegment(key uint64) usageIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *usageSet) LowerBoundSegment(min uint64) usageIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *usageSet) UpperBoundSegment(max uint64) usageIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *usageSet) FindGap(key uint64) usageGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *usageSet) LowerBoundGap(min uint64) usageGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *usageSet) UpperBoundGap(max uint64) usageGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *usageSet) Add(r __generics_imported0.FileRange, val usageInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *usageSet) AddWithoutMerging(r __generics_imported0.FileRange, val usageInfo) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *usageSet) Insert(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (usageSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (usageSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *usageSet) InsertWithoutMerging(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *usageSet) InsertWithoutMergingUnchecked(gap usageGapIterator, r __generics_imported0.FileRange, val usageInfo) usageIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return usageIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *usageSet) Remove(seg usageIterator) usageGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + usageSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(usageGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *usageSet) RemoveAll() { + s.root = usagenode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *usageSet) RemoveRange(r __generics_imported0.FileRange) usageGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *usageSet) Merge(first, second usageIterator) usageIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *usageSet) MergeUnchecked(first, second usageIterator) usageIterator { + if first.End() == second.Start() { + if mval, ok := (usageSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return usageIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *usageSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *usageSet) MergeRange(r __generics_imported0.FileRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *usageSet) MergeAdjacent(r __generics_imported0.FileRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *usageSet) Split(seg usageIterator, split uint64) (usageIterator, usageIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *usageSet) SplitUnchecked(seg usageIterator, split uint64) (usageIterator, usageIterator) { + val1, val2 := (usageSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), __generics_imported0.FileRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *usageSet) SplitAt(split uint64) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *usageSet) Isolate(seg usageIterator, r __generics_imported0.FileRange) usageIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *usageSet) ApplyContiguous(r __generics_imported0.FileRange, fn func(seg usageIterator)) usageGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return usageGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return usageGapIterator{} + } + } +} + +// +stateify savable +type usagenode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *usagenode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [usagemaxDegree - 1]__generics_imported0.FileRange + values [usagemaxDegree - 1]usageInfo + children [usagemaxDegree]*usagenode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *usagenode) firstSegment() usageIterator { + for n.hasChildren { + n = n.children[0] + } + return usageIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *usagenode) lastSegment() usageIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return usageIterator{n, n.nrSegments - 1} +} + +func (n *usagenode) prevSibling() *usagenode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *usagenode) nextSibling() *usagenode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *usagenode) rebalanceBeforeInsert(gap usageGapIterator) usageGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < usagemaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:usageminDegree-1], n.keys[:usageminDegree-1]) + copy(left.values[:usageminDegree-1], n.values[:usageminDegree-1]) + copy(right.keys[:usageminDegree-1], n.keys[usageminDegree:]) + copy(right.values[:usageminDegree-1], n.values[usageminDegree:]) + n.keys[0], n.values[0] = n.keys[usageminDegree-1], n.values[usageminDegree-1] + usagezeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:usageminDegree], n.children[:usageminDegree]) + copy(right.children[:usageminDegree], n.children[usageminDegree:]) + usagezeroNodeSlice(n.children[2:]) + for i := 0; i < usageminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < usageminDegree { + return usageGapIterator{left, gap.index} + } + return usageGapIterator{right, gap.index - usageminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[usageminDegree-1], n.values[usageminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &usagenode{ + nrSegments: usageminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:usageminDegree-1], n.keys[usageminDegree:]) + copy(sibling.values[:usageminDegree-1], n.values[usageminDegree:]) + usagezeroValueSlice(n.values[usageminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:usageminDegree], n.children[usageminDegree:]) + usagezeroNodeSlice(n.children[usageminDegree:]) + for i := 0; i < usageminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = usageminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < usageminDegree { + return gap + } + return usageGapIterator{sibling, gap.index - usageminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *usagenode) rebalanceAfterRemove(gap usageGapIterator) usageGapIterator { + for { + if n.nrSegments >= usageminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= usageminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return usageGapIterator{n, 0} + } + if gap.node == n { + return usageGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= usageminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + usageSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return usageGapIterator{n, n.nrSegments} + } + return usageGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return usageGapIterator{p, gap.index} + } + if gap.node == right { + return usageGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *usagenode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = usageGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + usageSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type usageIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *usagenode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg usageIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg usageIterator) Range() __generics_imported0.FileRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg usageIterator) Start() uint64 { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg usageIterator) End() uint64 { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg usageIterator) SetRangeUnchecked(r __generics_imported0.FileRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetRange(r __generics_imported0.FileRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg usageIterator) SetStartUnchecked(start uint64) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetStart(start uint64) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg usageIterator) SetEndUnchecked(end uint64) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg usageIterator) SetEnd(end uint64) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg usageIterator) Value() usageInfo { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg usageIterator) ValuePtr() *usageInfo { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg usageIterator) SetValue(val usageInfo) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg usageIterator) PrevSegment() usageIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return usageIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return usageIterator{} + } + return usagesegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg usageIterator) NextSegment() usageIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return usageIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return usageIterator{} + } + return usagesegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg usageIterator) PrevGap() usageGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return usageGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg usageIterator) NextGap() usageGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return usageGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg usageIterator) PrevNonEmpty() (usageIterator, usageGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return usageIterator{}, gap + } + return gap.PrevSegment(), usageGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg usageIterator) NextNonEmpty() (usageIterator, usageGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return usageIterator{}, gap + } + return gap.NextSegment(), usageGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type usageGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *usagenode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap usageGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap usageGapIterator) Range() __generics_imported0.FileRange { + return __generics_imported0.FileRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap usageGapIterator) Start() uint64 { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return usageSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap usageGapIterator) End() uint64 { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return usageSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap usageGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap usageGapIterator) PrevSegment() usageIterator { + return usagesegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap usageGapIterator) NextSegment() usageIterator { + return usagesegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap usageGapIterator) PrevGap() usageGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return usageGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap usageGapIterator) NextGap() usageGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return usageGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func usagesegmentBeforePosition(n *usagenode, i int) usageIterator { + for i == 0 { + if n.parent == nil { + return usageIterator{} + } + n, i = n.parent, n.parentIndex + } + return usageIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func usagesegmentAfterPosition(n *usagenode, i int) usageIterator { + for i == n.nrSegments { + if n.parent == nil { + return usageIterator{} + } + n, i = n.parent, n.parentIndex + } + return usageIterator{n, i} +} + +func usagezeroValueSlice(slice []usageInfo) { + + for i := range slice { + usageSetFunctions{}.ClearValue(&slice[i]) + } +} + +func usagezeroNodeSlice(slice []*usagenode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *usageSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *usagenode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *usagenode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type usageSegmentDataSlices struct { + Start []uint64 + End []uint64 + Values []usageInfo +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *usageSet) ExportSortedSlices() *usageSegmentDataSlices { + var sds usageSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *usageSet) ImportSortedSlices(sds *usageSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := __generics_imported0.FileRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *usageSet) saveRoot() *usageSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *usageSet) loadRoot(sds *usageSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD deleted file mode 100644 index 0b9962b2b..000000000 --- a/pkg/sentry/platform/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -go_template_instance( - name = "file_range", - out = "file_range.go", - package = "platform", - prefix = "File", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - -go_library( - name = "platform", - srcs = [ - "context.go", - "file_range.go", - "mmap_min_addr.go", - "platform.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/log", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/safemem", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/platform/file_range.go b/pkg/sentry/platform/file_range.go new file mode 100755 index 000000000..685d360e3 --- /dev/null +++ b/pkg/sentry/platform/file_range.go @@ -0,0 +1,62 @@ +package platform + +// A Range represents a contiguous range of T. +// +// +stateify savable +type FileRange struct { + // Start is the inclusive start of the range. + Start uint64 + + // End is the exclusive end of the range. + End uint64 +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r FileRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r FileRange) Length() uint64 { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r FileRange) Contains(x uint64) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r FileRange) Overlaps(r2 FileRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r FileRange) IsSupersetOf(r2 FileRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r FileRange) Intersect(r2 FileRange) FileRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r FileRange) CanSplitAt(x uint64) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD deleted file mode 100644 index eeb634644..000000000 --- a/pkg/sentry/platform/interrupt/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "interrupt", - srcs = [ - "interrupt.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "interrupt_test", - size = "small", - srcs = ["interrupt_test.go"], - embed = [":interrupt"], -) diff --git a/pkg/sentry/platform/interrupt/interrupt_state_autogen.go b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go new file mode 100755 index 000000000..15e8bacdf --- /dev/null +++ b/pkg/sentry/platform/interrupt/interrupt_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package interrupt + diff --git a/pkg/sentry/platform/interrupt/interrupt_test.go b/pkg/sentry/platform/interrupt/interrupt_test.go deleted file mode 100644 index 0ecdf6e7a..000000000 --- a/pkg/sentry/platform/interrupt/interrupt_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package interrupt - -import ( - "testing" -) - -type countingReceiver struct { - interrupts int -} - -// NotifyInterrupt implements Receiver.NotifyInterrupt. -func (r *countingReceiver) NotifyInterrupt() { - r.interrupts++ -} - -func TestSingleInterruptBeforeEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - f.NotifyInterrupt() - // The interrupt should cause the first Enable to fail. - if f.Enable(&r) { - f.Disable() - t.Fatalf("Enable: got true, wanted false") - } - // The failing Enable "acknowledges" the interrupt, allowing future Enables - // to succeed. - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - f.Disable() -} - -func TestMultipleInterruptsBeforeEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - f.NotifyInterrupt() - f.NotifyInterrupt() - // The interrupts should cause the first Enable to fail. - if f.Enable(&r) { - f.Disable() - t.Fatalf("Enable: got true, wanted false") - } - // Interrupts are deduplicated while the Forwarder is disabled, so the - // failing Enable "acknowledges" all interrupts, allowing future Enables to - // succeed. - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - f.Disable() -} - -func TestSingleInterruptAfterEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - defer f.Disable() - f.NotifyInterrupt() - if r.interrupts != 1 { - t.Errorf("interrupts: got %d, wanted 1", r.interrupts) - } -} - -func TestMultipleInterruptsAfterEnable(t *testing.T) { - var ( - f Forwarder - r countingReceiver - ) - if !f.Enable(&r) { - t.Fatalf("Enable: got false, wanted true") - } - defer f.Disable() - f.NotifyInterrupt() - f.NotifyInterrupt() - if r.interrupts != 2 { - t.Errorf("interrupts: got %d, wanted 2", r.interrupts) - } -} diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD deleted file mode 100644 index 9ccf77fdf..000000000 --- a/pkg/sentry/platform/kvm/BUILD +++ /dev/null @@ -1,66 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "kvm", - srcs = [ - "address_space.go", - "allocator.go", - "bluepill.go", - "bluepill_amd64.go", - "bluepill_amd64.s", - "bluepill_amd64_unsafe.go", - "bluepill_fault.go", - "bluepill_unsafe.go", - "context.go", - "kvm.go", - "kvm_amd64.go", - "kvm_amd64_unsafe.go", - "kvm_const.go", - "machine.go", - "machine_amd64.go", - "machine_amd64_unsafe.go", - "machine_unsafe.go", - "physical_map.go", - "virtual_map.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/atomicbitops", - "//pkg/cpuid", - "//pkg/log", - "//pkg/procid", - "//pkg/sentry/arch", - "//pkg/sentry/platform", - "//pkg/sentry/platform/interrupt", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/time", - "//pkg/sentry/usermem", - ], -) - -go_test( - name = "kvm_test", - srcs = [ - "kvm_test.go", - "virtual_map_test.go", - ], - embed = [":kvm"], - tags = [ - "nogotsan", - "requires-kvm", - ], - deps = [ - "//pkg/sentry/arch", - "//pkg/sentry/platform", - "//pkg/sentry/platform/kvm/testutil", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/platform/kvm/kvm_state_autogen.go b/pkg/sentry/platform/kvm/kvm_state_autogen.go new file mode 100755 index 000000000..5ab0e0735 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package kvm + diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go deleted file mode 100644 index 30df725d4..000000000 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ /dev/null @@ -1,533 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvm - -import ( - "math/rand" - "reflect" - "sync/atomic" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -var dummyFPState = (*byte)(arch.NewFloatingPointData()) - -type testHarness interface { - Errorf(format string, args ...interface{}) - Fatalf(format string, args ...interface{}) -} - -func kvmTest(t testHarness, setup func(*KVM), fn func(*vCPU) bool) { - // Create the machine. - deviceFile, err := OpenDevice() - if err != nil { - t.Fatalf("error opening device file: %v", err) - } - k, err := New(deviceFile) - if err != nil { - t.Fatalf("error creating KVM instance: %v", err) - } - defer k.machine.Destroy() - - // Call additional setup. - if setup != nil { - setup(k) - } - - var c *vCPU // For recovery. - defer func() { - redpill() - if c != nil { - k.machine.Put(c) - } - }() - for { - c = k.machine.Get() - if !fn(c) { - break - } - - // We put the vCPU here and clear the value so that the - // deferred recovery will not re-put it above. - k.machine.Put(c) - c = nil - } -} - -func bluepillTest(t testHarness, fn func(*vCPU)) { - kvmTest(t, nil, func(c *vCPU) bool { - bluepill(c) - fn(c) - return false - }) -} - -func TestKernelSyscall(t *testing.T) { - bluepillTest(t, func(c *vCPU) { - redpill() // Leave guest mode. - if got := atomic.LoadUint32(&c.state); got != vCPUUser { - t.Errorf("vCPU not in ready state: got %v", got) - } - }) -} - -func hostFault() { - defer func() { - recover() - }() - var foo *int - *foo = 0 -} - -func TestKernelFault(t *testing.T) { - hostFault() // Ensure recovery works. - bluepillTest(t, func(c *vCPU) { - hostFault() - if got := atomic.LoadUint32(&c.state); got != vCPUUser { - t.Errorf("vCPU not in ready state: got %v", got) - } - }) -} - -func TestKernelFloatingPoint(t *testing.T) { - bluepillTest(t, func(c *vCPU) { - if !testutil.FloatingPointWorks() { - t.Errorf("floating point does not work, and it should!") - } - }) -} - -func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *syscall.PtraceRegs, *pagetables.PageTables) bool) { - // Initialize registers & page tables. - var ( - regs syscall.PtraceRegs - pt *pagetables.PageTables - ) - testutil.SetTestTarget(®s, target) - - kvmTest(t, func(k *KVM) { - // Create new page tables. - as, _, err := k.NewAddressSpace(nil /* invalidator */) - if err != nil { - t.Fatalf("can't create new address space: %v", err) - } - pt = as.(*addressSpace).pageTables - - if useHostMappings { - // Apply the physical mappings to these page tables. - // (This is normally dangerous, since they point to - // physical pages that may not exist. This shouldn't be - // done for regular user code, but is fine for test - // purposes.) - applyPhysicalRegions(func(pr physicalRegion) bool { - pt.Map(usermem.Addr(pr.virtual), pr.length, pagetables.MapOpts{ - AccessType: usermem.AnyAccess, - User: true, - }, pr.physical) - return true // Keep iterating. - }) - } - }, func(c *vCPU) bool { - // Invoke the function with the extra data. - return fn(c, ®s, pt) - }) -} - -func TestApplicationSyscall(t *testing.T) { - applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != nil { - t.Errorf("application syscall with full restore failed: %v", err) - } - return false - }) - applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != nil { - t.Errorf("application syscall with partial restore failed: %v", err) - } - return false - }) -} - -func TestApplicationFault(t *testing.T) { - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) { - t.Errorf("application fault with full restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) - } - return false - }) - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) { - t.Errorf("application fault with partial restore got (%v, %v), expected (%v, SIGSEGV)", err, si, platform.ErrContextSignal) - } - return false - }) -} - -func TestRegistersSyscall(t *testing.T) { - applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTestRegs(regs) // Fill values for all registers. - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application register check with partial restore got unexpected error: %v", err) - } - if err := testutil.CheckTestRegs(regs, false); err != nil { - t.Errorf("application register check with partial restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestRegistersFault(t *testing.T) { - applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTestRegs(regs) // Fill values for all registers. - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != platform.ErrContextSignal || si.Signo != int32(syscall.SIGSEGV) { - t.Errorf("application register check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestRegs(regs, true); err != nil { - t.Errorf("application register check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTestSegments(regs) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application segment check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestSegments(regs); err != nil { - t.Errorf("application segment check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - -func TestBounce(t *testing.T) { - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - go func() { - time.Sleep(time.Millisecond) - c.BounceToKernel() - }() - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - return false - }) - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - go func() { - time.Sleep(time.Millisecond) - c.BounceToKernel() - }() - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application full restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - return false - }) -} - -func TestBounceStress(t *testing.T) { - applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - randomSleep := func() { - // O(hundreds of microseconds) is appropriate to ensure - // different overlaps and different schedules. - if n := rand.Intn(1000); n > 100 { - time.Sleep(time.Duration(n) * time.Microsecond) - } - } - for i := 0; i < 1000; i++ { - // Start an asynchronously executing goroutine that - // calls Bounce at pseudo-random point in time. - // This should wind up calling Bounce when the - // kernel is in various stages of the switch. - go func() { - randomSleep() - c.BounceToKernel() - }() - randomSleep() - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err != platform.ErrContextInterrupt { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextInterrupt) - } - c.unlock() - randomSleep() - c.lock() - } - return false - }) -} - -func TestInvalidate(t *testing.T) { - var data uintptr // Used below. - applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - testutil.SetTouchTarget(regs, &data) // Read legitimate value. - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application partial restore: got %v, wanted nil", err) - } - break // Done. - } - // Unmap the page containing data & invalidate. - pt.Unmap(usermem.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(usermem.PageSize-1)), usermem.PageSize) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - Flush: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != platform.ErrContextSignal { - t.Errorf("application partial restore: got %v, wanted %v", err, platform.ErrContextSignal) - } - break // Success. - } - return false - }) -} - -// IsFault returns true iff the given signal represents a fault. -func IsFault(err error, si *arch.SignalInfo) bool { - return err == platform.ErrContextSignal && si.Signo == int32(syscall.SIGSEGV) -} - -func TestEmptyAddressSpace(t *testing.T) { - applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if !IsFault(err, &si) { - t.Errorf("first fault with partial restore failed got %v", err) - t.Logf("registers: %#v", ®s) - } - return false - }) - applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - return true // Retry. - } else if !IsFault(err, &si) { - t.Errorf("first fault with full restore failed got %v", err) - t.Logf("registers: %#v", ®s) - } - return false - }) -} - -func TestWrongVCPU(t *testing.T) { - kvmTest(t, nil, func(c1 *vCPU) bool { - kvmTest(t, nil, func(c2 *vCPU) bool { - // Basic test, one then the other. - bluepill(c1) - bluepill(c2) - if c2.switches == 0 { - // Don't allow the test to proceed if this fails. - t.Fatalf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - - // Alternate vCPUs; we expect to need to trigger the - // wrong vCPU path on each switch. - for i := 0; i < 100; i++ { - bluepill(c1) - bluepill(c2) - } - if count := c1.switches; count < 90 { - t.Errorf("wrong vCPU#1 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - if count := c2.switches; count < 90 { - t.Errorf("wrong vCPU#2 switches: vCPU1=%+v,vCPU2=%+v", c1, c2) - } - return false - }) - return false - }) - kvmTest(t, nil, func(c1 *vCPU) bool { - kvmTest(t, nil, func(c2 *vCPU) bool { - bluepill(c1) - bluepill(c2) - return false - }) - return false - }) -} - -func BenchmarkApplicationSyscall(b *testing.B) { - var ( - i int // Iteration includes machine.Get() / machine.Put(). - a int // Count for ErrContextInterrupt. - ) - applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - a++ - return true // Ignore. - } else if err != nil { - b.Fatalf("benchmark failed: %v", err) - } - i++ - return i < b.N - }) - if a != 0 { - b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i) - } -} - -func BenchmarkKernelSyscall(b *testing.B) { - // Note that the target passed here is irrelevant, we never execute SwitchToUser. - applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - // iteration does not include machine.Get() / machine.Put(). - for i := 0; i < b.N; i++ { - testutil.Getpid() - } - return false - }) -} - -func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { - // see BenchmarkApplicationSyscall. - var ( - i int - a int - ) - applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - }, &si); err == platform.ErrContextInterrupt { - a++ - return true // Ignore. - } else if err != nil { - b.Fatalf("benchmark failed: %v", err) - } - // This will intentionally cause the world switch. By executing - // a host syscall here, we force the transition between guest - // and host mode. - testutil.Getpid() - i++ - return i < b.N - }) - if a != 0 { - b.Logf("ErrContextInterrupt occurred %d times (in %d iterations).", a, a+i) - } -} diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD deleted file mode 100644 index 77a449a8b..000000000 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testutil", - testonly = 1, - srcs = [ - "testutil.go", - "testutil_amd64.go", - "testutil_amd64.s", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil", - visibility = ["//pkg/sentry/platform/kvm:__pkg__"], -) diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go deleted file mode 100644 index 6cf2359a3..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil provides common assembly stubs for testing. -package testutil - -import ( - "fmt" - "strings" -) - -// Getpid executes a trivial system call. -func Getpid() - -// Touch touches the value in the first register. -func Touch() - -// SyscallLoop executes a syscall and loops. -func SyscallLoop() - -// SpinLoop spins on the CPU. -func SpinLoop() - -// HaltLoop immediately halts and loops. -func HaltLoop() - -// TwiddleRegsFault twiddles registers then faults. -func TwiddleRegsFault() - -// TwiddleRegsSyscall twiddles registers then executes a syscall. -func TwiddleRegsSyscall() - -// TwiddleSegments reads segments into known registers. -func TwiddleSegments() - -// FloatingPointWorks is a floating point test. -// -// It returns true or false. -func FloatingPointWorks() bool - -// RegisterMismatchError is used for checking registers. -type RegisterMismatchError []string - -// Error returns a human-readable error. -func (r RegisterMismatchError) Error() string { - return strings.Join([]string(r), ";") -} - -// addRegisterMisatch allows simple chaining of register mismatches. -func addRegisterMismatch(err error, reg string, got, expected interface{}) error { - errStr := fmt.Sprintf("%s got %08x, expected %08x", reg, got, expected) - switch r := err.(type) { - case nil: - // Return a new register mismatch. - return RegisterMismatchError{errStr} - case RegisterMismatchError: - // Append the error. - r = append(r, errStr) - return r - default: - // Leave as is. - return err - } -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go deleted file mode 100644 index 203d71528..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package testutil - -import ( - "reflect" - "syscall" -) - -// SetTestTarget sets the rip appropriately. -func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { - regs.Rip = uint64(reflect.ValueOf(fn).Pointer()) -} - -// SetTouchTarget sets rax appropriately. -func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { - if target != nil { - regs.Rax = uint64(reflect.ValueOf(target).Pointer()) - } else { - regs.Rax = 0 - } -} - -// RewindSyscall rewinds a syscall RIP. -func RewindSyscall(regs *syscall.PtraceRegs) { - regs.Rip -= 2 -} - -// SetTestRegs initializes registers to known values. -func SetTestRegs(regs *syscall.PtraceRegs) { - regs.R15 = 0x15 - regs.R14 = 0x14 - regs.R13 = 0x13 - regs.R12 = 0x12 - regs.Rbp = 0xb9 - regs.Rbx = 0xb4 - regs.R11 = 0x11 - regs.R10 = 0x10 - regs.R9 = 0x09 - regs.R8 = 0x08 - regs.Rax = 0x44 - regs.Rcx = 0xc4 - regs.Rdx = 0xd4 - regs.Rsi = 0x51 - regs.Rdi = 0xd1 - regs.Rsp = 0x59 -} - -// CheckTestRegs checks that registers were twiddled per TwiddleRegs. -func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { - if need := ^uint64(0x15); regs.R15 != need { - err = addRegisterMismatch(err, "R15", regs.R15, need) - } - if need := ^uint64(0x14); regs.R14 != need { - err = addRegisterMismatch(err, "R14", regs.R14, need) - } - if need := ^uint64(0x13); regs.R13 != need { - err = addRegisterMismatch(err, "R13", regs.R13, need) - } - if need := ^uint64(0x12); regs.R12 != need { - err = addRegisterMismatch(err, "R12", regs.R12, need) - } - if need := ^uint64(0xb9); regs.Rbp != need { - err = addRegisterMismatch(err, "Rbp", regs.Rbp, need) - } - if need := ^uint64(0xb4); regs.Rbx != need { - err = addRegisterMismatch(err, "Rbx", regs.Rbx, need) - } - if need := ^uint64(0x10); regs.R10 != need { - err = addRegisterMismatch(err, "R10", regs.R10, need) - } - if need := ^uint64(0x09); regs.R9 != need { - err = addRegisterMismatch(err, "R9", regs.R9, need) - } - if need := ^uint64(0x08); regs.R8 != need { - err = addRegisterMismatch(err, "R8", regs.R8, need) - } - if need := ^uint64(0x44); regs.Rax != need { - err = addRegisterMismatch(err, "Rax", regs.Rax, need) - } - if need := ^uint64(0xd4); regs.Rdx != need { - err = addRegisterMismatch(err, "Rdx", regs.Rdx, need) - } - if need := ^uint64(0x51); regs.Rsi != need { - err = addRegisterMismatch(err, "Rsi", regs.Rsi, need) - } - if need := ^uint64(0xd1); regs.Rdi != need { - err = addRegisterMismatch(err, "Rdi", regs.Rdi, need) - } - if need := ^uint64(0x59); regs.Rsp != need { - err = addRegisterMismatch(err, "Rsp", regs.Rsp, need) - } - // Rcx & R11 are ignored if !full is set. - if need := ^uint64(0x11); full && regs.R11 != need { - err = addRegisterMismatch(err, "R11", regs.R11, need) - } - if need := ^uint64(0xc4); full && regs.Rcx != need { - err = addRegisterMismatch(err, "Rcx", regs.Rcx, need) - } - return -} - -var fsData uint64 = 0x55 -var gsData uint64 = 0x85 - -// SetTestSegments initializes segments to known values. -func SetTestSegments(regs *syscall.PtraceRegs) { - regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer()) - regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer()) -} - -// CheckTestSegments checks that registers were twiddled per TwiddleSegments. -func CheckTestSegments(regs *syscall.PtraceRegs) (err error) { - if regs.Rax != fsData { - err = addRegisterMismatch(err, "Rax", regs.Rax, fsData) - } - if regs.Rbx != gsData { - err = addRegisterMismatch(err, "Rbx", regs.Rcx, gsData) - } - return -} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s b/pkg/sentry/platform/kvm/testutil/testutil_amd64.s deleted file mode 100644 index 491ec0c2a..000000000 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.s +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -// test_util_amd64.s provides AMD64 test functions. - -#include "funcdata.h" -#include "textflag.h" - -TEXT ·Getpid(SB),NOSPLIT,$0 - NO_LOCAL_POINTERS - MOVQ $39, AX // getpid - SYSCALL - RET - -TEXT ·Touch(SB),NOSPLIT,$0 -start: - MOVQ 0(AX), BX // deref AX - MOVQ $39, AX // getpid - SYSCALL - JMP start - -TEXT ·HaltLoop(SB),NOSPLIT,$0 -start: - HLT - JMP start - -TEXT ·SyscallLoop(SB),NOSPLIT,$0 -start: - SYSCALL - JMP start - -TEXT ·SpinLoop(SB),NOSPLIT,$0 -start: - JMP start - -TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 - NO_LOCAL_POINTERS - MOVQ $1, AX - MOVQ AX, X0 - MOVQ $39, AX // getpid - SYSCALL - MOVQ X0, AX - CMPQ AX, $1 - SETEQ ret+0(FP) - RET - -#define TWIDDLE_REGS() \ - NOTQ R15; \ - NOTQ R14; \ - NOTQ R13; \ - NOTQ R12; \ - NOTQ BP; \ - NOTQ BX; \ - NOTQ R11; \ - NOTQ R10; \ - NOTQ R9; \ - NOTQ R8; \ - NOTQ AX; \ - NOTQ CX; \ - NOTQ DX; \ - NOTQ SI; \ - NOTQ DI; \ - NOTQ SP; - -TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 - TWIDDLE_REGS() - SYSCALL - RET // never reached - -TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 - TWIDDLE_REGS() - JMP AX // must fault - RET // never reached - -#define READ_FS() BYTE $0x64; BYTE $0x48; BYTE $0x8b; BYTE $0x00; -#define READ_GS() BYTE $0x65; BYTE $0x48; BYTE $0x8b; BYTE $0x00; - -TEXT ·TwiddleSegments(SB),NOSPLIT,$0 - MOVQ $0x0, AX - READ_GS() - MOVQ AX, BX - MOVQ $0x0, AX - READ_FS() - SYSCALL - RET // never reached diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go deleted file mode 100644 index 6a2f145be..000000000 --- a/pkg/sentry/platform/kvm/virtual_map_test.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvm - -import ( - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -type checker struct { - ok bool - accessType usermem.AccessType -} - -func (c *checker) Containing(addr uintptr) func(virtualRegion) { - c.ok = false // Reset for below calls. - return func(vr virtualRegion) { - if vr.virtual <= addr && addr < vr.virtual+vr.length { - c.ok = true - c.accessType = vr.accessType - } - } -} - -func TestParseMaps(t *testing.T) { - c := new(checker) - - // Simple test. - if err := applyVirtualRegions(c.Containing(0)); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // MMap a new page. - addr, _, errno := syscall.RawSyscall6( - syscall.SYS_MMAP, 0, usermem.PageSize, - syscall.PROT_READ|syscall.PROT_WRITE, - syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE, 0, 0) - if errno != 0 { - t.Fatalf("unexpected map error: %v", errno) - } - - // Re-parse maps. - if err := applyVirtualRegions(c.Containing(addr)); err != nil { - syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) - t.Fatalf("unexpected error: %v", err) - } - - // Assert that it now does contain the region. - if !c.ok { - syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) - t.Fatalf("updated map does not contain 0x%08x, expected true", addr) - } - - // Map the region as PROT_NONE. - newAddr, _, errno := syscall.RawSyscall6( - syscall.SYS_MMAP, addr, usermem.PageSize, - syscall.PROT_NONE, - syscall.MAP_ANONYMOUS|syscall.MAP_FIXED|syscall.MAP_PRIVATE, 0, 0) - if errno != 0 { - t.Fatalf("unexpected map error: %v", errno) - } - if newAddr != addr { - t.Fatalf("unable to remap address: got 0x%08x, wanted 0x%08x", newAddr, addr) - } - - // Re-parse maps. - if err := applyVirtualRegions(c.Containing(addr)); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !c.ok { - t.Fatalf("final map does not contain 0x%08x, expected true", addr) - } - if c.accessType.Read || c.accessType.Write || c.accessType.Execute { - t.Fatalf("final map has incorrect permissions for 0x%08x", addr) - } - - // Unmap the region. - syscall.RawSyscall(syscall.SYS_MUNMAP, addr, usermem.PageSize, 0) -} diff --git a/pkg/sentry/platform/platform_state_autogen.go b/pkg/sentry/platform/platform_state_autogen.go new file mode 100755 index 000000000..28dcd1764 --- /dev/null +++ b/pkg/sentry/platform/platform_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package platform + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FileRange) beforeSave() {} +func (x *FileRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *FileRange) afterLoad() {} +func (x *FileRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func init() { + state.Register("platform.FileRange", (*FileRange)(nil), state.Fns{Save: (*FileRange).save, Load: (*FileRange).load}) +} diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD deleted file mode 100644 index 6a1343f47..000000000 --- a/pkg/sentry/platform/ptrace/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "ptrace", - srcs = [ - "ptrace.go", - "ptrace_unsafe.go", - "stub_amd64.s", - "stub_unsafe.go", - "subprocess.go", - "subprocess_amd64.go", - "subprocess_linux.go", - "subprocess_linux_amd64_unsafe.go", - "subprocess_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ptrace", - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/procid", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/platform", - "//pkg/sentry/platform/interrupt", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/platform/ptrace/ptrace_state_autogen.go b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go new file mode 100755 index 000000000..ac83a71e7 --- /dev/null +++ b/pkg/sentry/platform/ptrace/ptrace_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package ptrace + diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD deleted file mode 100644 index 8ed6c7652..000000000 --- a/pkg/sentry/platform/ring0/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -go_template( - name = "defs", - srcs = [ - "defs.go", - "defs_amd64.go", - "offsets_amd64.go", - "x86.go", - ], - visibility = [":__subpackages__"], -) - -go_template_instance( - name = "defs_impl", - out = "defs_impl.go", - package = "ring0", - template = ":defs", -) - -genrule( - name = "entry_impl_amd64", - srcs = ["entry_amd64.s"], - outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], -) - -go_library( - name = "ring0", - srcs = [ - "defs_impl.go", - "entry_amd64.go", - "entry_impl_amd64.s", - "kernel.go", - "kernel_amd64.go", - "kernel_unsafe.go", - "lib_amd64.go", - "lib_amd64.s", - "ring0.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/cpuid", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go deleted file mode 100644 index 076063f85..000000000 --- a/pkg/sentry/platform/ring0/defs.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ring0 - -import ( - "syscall" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - -// Kernel is a global kernel object. -// -// This contains global state, shared by multiple CPUs. -type Kernel struct { - KernelArchState -} - -// Hooks are hooks for kernel functions. -type Hooks interface { - // KernelSyscall is called for kernel system calls. - // - // Return from this call will restore registers and return to the kernel: the - // registers must be modified directly. - // - // If this function is not provided, a kernel exception results in halt. - // - // This must be go:nosplit, as this will be on the interrupt stack. - // Closures are permitted, as the pointer to the closure frame is not - // passed on the stack. - KernelSyscall() - - // KernelException handles an exception during kernel execution. - // - // Return from this call will restore registers and return to the kernel: the - // registers must be modified directly. - // - // If this function is not provided, a kernel exception results in halt. - // - // This must be go:nosplit, as this will be on the interrupt stack. - // Closures are permitted, as the pointer to the closure frame is not - // passed on the stack. - KernelException(Vector) -} - -// CPU is the per-CPU struct. -type CPU struct { - // self is a self reference. - // - // This is always guaranteed to be at offset zero. - self *CPU - - // kernel is reference to the kernel that this CPU was initialized - // with. This reference is kept for garbage collection purposes: CPU - // registers may refer to objects within the Kernel object that cannot - // be safely freed. - kernel *Kernel - - // CPUArchState is architecture-specific state. - CPUArchState - - // registers is a set of registers; these may be used on kernel system - // calls and exceptions via the Registers function. - registers syscall.PtraceRegs - - // hooks are kernel hooks. - hooks Hooks -} - -// Registers returns a modifiable-copy of the kernel registers. -// -// This is explicitly safe to call during KernelException and KernelSyscall. -// -//go:nosplit -func (c *CPU) Registers() *syscall.PtraceRegs { - return &c.registers -} - -// SwitchOpts are passed to the Switch function. -type SwitchOpts struct { - // Registers are the user register state. - Registers *syscall.PtraceRegs - - // FloatingPointState is a byte pointer where floating point state is - // saved and restored. - FloatingPointState *byte - - // PageTables are the application page tables. - PageTables *pagetables.PageTables - - // Flush indicates that a TLB flush should be forced on switch. - Flush bool - - // FullRestore indicates that an iret-based restore should be used. - FullRestore bool - - // SwitchArchOpts are architecture-specific options. - SwitchArchOpts -} diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go deleted file mode 100644 index 7206322b1..000000000 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" -) - -// Segment indices and Selectors. -const ( - // Index into GDT array. - _ = iota // Null descriptor first. - _ // Reserved (Linux is kernel 32). - segKcode // Kernel code (64-bit). - segKdata // Kernel data. - segUcode32 // User code (32-bit). - segUdata // User data. - segUcode64 // User code (64-bit). - segTss // Task segment descriptor. - segTssHi // Upper bits for TSS. - segLast // Last segment (terminal, not included). -) - -// Selectors. -const ( - Kcode Selector = segKcode << 3 - Kdata Selector = segKdata << 3 - Ucode32 Selector = (segUcode32 << 3) | 3 - Udata Selector = (segUdata << 3) | 3 - Ucode64 Selector = (segUcode64 << 3) | 3 - Tss Selector = segTss << 3 -) - -// Standard segments. -var ( - UserCodeSegment32 SegmentDescriptor - UserDataSegment SegmentDescriptor - UserCodeSegment64 SegmentDescriptor - KernelCodeSegment SegmentDescriptor - KernelDataSegment SegmentDescriptor -) - -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - -// KernelArchState contains architecture-specific state. -type KernelArchState struct { - KernelOpts - - // globalIDT is our set of interrupt gates. - globalIDT idt64 -} - -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { - // stack is the stack used for interrupts on this CPU. - stack [256]byte - - // errorCode is the error code from the last exception. - errorCode uintptr - - // errorType indicates the type of error code here, it is always set - // along with the errorCode value above. - // - // It will either by 1, which indicates a user error, or 0 indicating a - // kernel error. If the error code below returns false (kernel error), - // then it cannot provide relevant information about the last - // exception. - errorType uintptr - - // gdt is the CPU's descriptor table. - gdt descriptorTable - - // tss is the CPU's task state. - tss TaskState64 -} - -// ErrorCode returns the last error code. -// -// The returned boolean indicates whether the error code corresponds to the -// last user error or not. If it does not, then fault information must be -// ignored. This is generally the result of a kernel fault while servicing a -// user fault. -// -//go:nosplit -func (c *CPU) ErrorCode() (value uintptr, user bool) { - return c.errorCode, c.errorType != 0 -} - -// ClearErrorCode resets the error code. -// -//go:nosplit -func (c *CPU) ClearErrorCode() { - c.errorCode = 0 // No code. - c.errorType = 1 // User mode. -} - -// SwitchArchOpts are embedded in SwitchOpts. -type SwitchArchOpts struct { - // UserPCID indicates that the application PCID to be used on switch, - // assuming that PCIDs are supported. - // - // Per pagetables_x86.go, a zero PCID implies a flush. - UserPCID uint16 - - // KernelPCID indicates that the kernel PCID to be used on return, - // assuming that PCIDs are supported. - // - // Per pagetables_x86.go, a zero PCID implies a flush. - KernelPCID uint16 -} - -func init() { - KernelCodeSegment.setCode64(0, 0, 0) - KernelDataSegment.setData(0, 0xffffffff, 0) - UserCodeSegment32.setCode64(0, 0, 3) - UserDataSegment.setData(0, 0xffffffff, 3) - UserCodeSegment64.setCode64(0, 0, 3) -} diff --git a/pkg/sentry/platform/ring0/defs_impl.go b/pkg/sentry/platform/ring0/defs_impl.go new file mode 100755 index 000000000..a30a9dd4a --- /dev/null +++ b/pkg/sentry/platform/ring0/defs_impl.go @@ -0,0 +1,538 @@ +package ring0 + +import ( + "gvisor.dev/gvisor/pkg/cpuid" + "io" + "reflect" + "syscall" + + "fmt" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// Kernel is a global kernel object. +// +// This contains global state, shared by multiple CPUs. +type Kernel struct { + KernelArchState +} + +// Hooks are hooks for kernel functions. +type Hooks interface { + // KernelSyscall is called for kernel system calls. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelSyscall() + + // KernelException handles an exception during kernel execution. + // + // Return from this call will restore registers and return to the kernel: the + // registers must be modified directly. + // + // If this function is not provided, a kernel exception results in halt. + // + // This must be go:nosplit, as this will be on the interrupt stack. + // Closures are permitted, as the pointer to the closure frame is not + // passed on the stack. + KernelException(Vector) +} + +// CPU is the per-CPU struct. +type CPU struct { + // self is a self reference. + // + // This is always guaranteed to be at offset zero. + self *CPU + + // kernel is reference to the kernel that this CPU was initialized + // with. This reference is kept for garbage collection purposes: CPU + // registers may refer to objects within the Kernel object that cannot + // be safely freed. + kernel *Kernel + + // CPUArchState is architecture-specific state. + CPUArchState + + // registers is a set of registers; these may be used on kernel system + // calls and exceptions via the Registers function. + registers syscall.PtraceRegs + + // hooks are kernel hooks. + hooks Hooks +} + +// Registers returns a modifiable-copy of the kernel registers. +// +// This is explicitly safe to call during KernelException and KernelSyscall. +// +//go:nosplit +func (c *CPU) Registers() *syscall.PtraceRegs { + return &c.registers +} + +// SwitchOpts are passed to the Switch function. +type SwitchOpts struct { + // Registers are the user register state. + Registers *syscall.PtraceRegs + + // FloatingPointState is a byte pointer where floating point state is + // saved and restored. + FloatingPointState *byte + + // PageTables are the application page tables. + PageTables *pagetables.PageTables + + // Flush indicates that a TLB flush should be forced on switch. + Flush bool + + // FullRestore indicates that an iret-based restore should be used. + FullRestore bool + + // SwitchArchOpts are architecture-specific options. + SwitchArchOpts +} + +// Segment indices and Selectors. +const ( + // Index into GDT array. + _ = iota // Null descriptor first. + _ // Reserved (Linux is kernel 32). + segKcode // Kernel code (64-bit). + segKdata // Kernel data. + segUcode32 // User code (32-bit). + segUdata // User data. + segUcode64 // User code (64-bit). + segTss // Task segment descriptor. + segTssHi // Upper bits for TSS. + segLast // Last segment (terminal, not included). +) + +// Selectors. +const ( + Kcode Selector = segKcode << 3 + Kdata Selector = segKdata << 3 + Ucode32 Selector = (segUcode32 << 3) | 3 + Udata Selector = (segUdata << 3) | 3 + Ucode64 Selector = (segUcode64 << 3) | 3 + Tss Selector = segTss << 3 +) + +// Standard segments. +var ( + UserCodeSegment32 SegmentDescriptor + UserDataSegment SegmentDescriptor + UserCodeSegment64 SegmentDescriptor + KernelCodeSegment SegmentDescriptor + KernelDataSegment SegmentDescriptor +) + +// KernelOpts has initialization options for the kernel. +type KernelOpts struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables +} + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + KernelOpts + + // globalIDT is our set of interrupt gates. + globalIDT idt64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [256]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 + c.errorType = 1 +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserPCID indicates that the application PCID to be used on switch, + // assuming that PCIDs are supported. + // + // Per pagetables_x86.go, a zero PCID implies a flush. + UserPCID uint16 + + // KernelPCID indicates that the kernel PCID to be used on return, + // assuming that PCIDs are supported. + // + // Per pagetables_x86.go, a zero PCID implies a flush. + KernelPCID uint16 +} + +func init() { + KernelCodeSegment.setCode64(0, 0, 0) + KernelDataSegment.setData(0, 0xffffffff, 0) + UserCodeSegment32.setCode64(0, 0, 3) + UserDataSegment.setData(0, 0xffffffff, 3) + UserCodeSegment64.setCode64(0, 0, 3) +} + +// Emit prints architecture-specific offsets. +func Emit(w io.Writer) { + fmt.Fprintf(w, "// Automatically generated, do not edit.\n") + + c := &CPU{} + fmt.Fprintf(w, "\n// CPU offsets.\n") + fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) + fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + + fmt.Fprintf(w, "\n// Bits.\n") + fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) + fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) + + fmt.Fprintf(w, "\n// Vectors.\n") + fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero) + fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug) + fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI) + fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint) + fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow) + fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded) + fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode) + fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable) + fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault) + fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun) + fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS) + fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent) + fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault) + fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault) + fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) + fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException) + fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck) + fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck) + fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException) + fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) + fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException) + fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) + fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) + + p := &syscall.PtraceRegs{} + fmt.Fprintf(w, "\n// Ptrace registers.\n") + fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_FS 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_GS 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer()) +} + +// Useful bits. +const ( + _CR0_PE = 1 << 0 + _CR0_ET = 1 << 4 + _CR0_AM = 1 << 18 + _CR0_PG = 1 << 31 + + _CR4_PSE = 1 << 4 + _CR4_PAE = 1 << 5 + _CR4_PGE = 1 << 7 + _CR4_OSFXSR = 1 << 9 + _CR4_OSXMMEXCPT = 1 << 10 + _CR4_FSGSBASE = 1 << 16 + _CR4_PCIDE = 1 << 17 + _CR4_OSXSAVE = 1 << 18 + _CR4_SMEP = 1 << 20 + + _RFLAGS_AC = 1 << 18 + _RFLAGS_NT = 1 << 14 + _RFLAGS_IOPL = 3 << 12 + _RFLAGS_DF = 1 << 10 + _RFLAGS_IF = 1 << 9 + _RFLAGS_STEP = 1 << 8 + _RFLAGS_RESERVED = 1 << 1 + + _EFER_SCE = 0x001 + _EFER_LME = 0x100 + _EFER_LMA = 0x400 + _EFER_NX = 0x800 + + _MSR_STAR = 0xc0000081 + _MSR_LSTAR = 0xc0000082 + _MSR_CSTAR = 0xc0000083 + _MSR_SYSCALL_MASK = 0xc0000084 + _MSR_PLATFORM_INFO = 0xce + _MSR_MISC_FEATURES = 0x140 + + _PLATFORM_INFO_CPUID_FAULT = 1 << 31 + + _MISC_FEATURE_CPUID_TRAP = 0x1 +) + +const ( + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _RFLAGS_RESERVED + + // UserFlagsSet are always set in userspace. + UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF + + // KernelFlagsClear should always be clear in the kernel. + KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT + + // UserFlagsClear are always cleared in userspace. + UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL +) + +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + DivideByZero Vector = iota + Debug + NMI + Breakpoint + Overflow + BoundRangeExceeded + InvalidOpcode + DeviceNotAvailable + DoubleFault + CoprocessorSegmentOverrun + InvalidTSS + SegmentNotPresent + StackSegmentFault + GeneralProtectionFault + PageFault + _ + X87FloatingPointException + AlignmentCheck + MachineCheck + SIMDFloatingPointException + VirtualizationException + SecurityException = 0x1e + SyscallInt80 = 0x80 + _NR_INTERRUPTS = SyscallInt80 + 1 +) + +// System call vectors. +const ( + Syscall Vector = _NR_INTERRUPTS +) + +// VirtualAddressBits returns the number bits available for virtual addresses. +// +// Note that sign-extension semantics apply to the highest order bit. +// +// FIXME(b/69382326): This should use the cpuid passed to Init. +func VirtualAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return (ax >> 8) & 0xff +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +// +// FIXME(b/69382326): This should use the cpuid passed to Init. +func PhysicalAddressBits() uint32 { + ax, _, _, _ := cpuid.HostID(0x80000008, 0) + return ax & 0xff +} + +// Selector is a segment Selector. +type Selector uint16 + +// SegmentDescriptor is a segment descriptor. +type SegmentDescriptor struct { + bits [2]uint32 +} + +// descriptorTable is a collection of descriptors. +type descriptorTable [32]SegmentDescriptor + +// SegmentDescriptorFlags are typed flags within a descriptor. +type SegmentDescriptorFlags uint32 + +// SegmentDescriptorFlag declarations. +const ( + SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set). + SegmentDescriptorWrite = 1 << 9 // Write permission. + SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used. + SegmentDescriptorExecute = 1 << 11 // Execute permission. + SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data. + SegmentDescriptorPresent = 1 << 15 // Present. + SegmentDescriptorAVL = 1 << 20 // Available. + SegmentDescriptorLong = 1 << 21 // Long mode. + SegmentDescriptorDB = 1 << 22 // 16 or 32-bit. + SegmentDescriptorG = 1 << 23 // Granularity: page or byte. +) + +// Base returns the descriptor's base linear address. +func (d *SegmentDescriptor) Base() uint32 { + return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16 +} + +// Limit returns the descriptor size. +func (d *SegmentDescriptor) Limit() uint32 { + l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000 + if d.bits[1]&uint32(SegmentDescriptorG) != 0 { + l <<= 12 + l |= 0xFFF + } + return l +} + +// Flags returns descriptor flags. +func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags { + return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00) +} + +// DPL returns the descriptor privilege level. +func (d *SegmentDescriptor) DPL() int { + return int((d.bits[1] >> 13) & 3) +} + +func (d *SegmentDescriptor) setNull() { + d.bits[0] = 0 + d.bits[1] = 0 +} + +func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) { + flags |= SegmentDescriptorPresent + if limit>>12 != 0 { + limit >>= 12 + flags |= SegmentDescriptorG + } + d.bits[0] = base<<16 | limit&0xFFFF + d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13 +} + +func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorDB| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorG| + SegmentDescriptorLong| + SegmentDescriptorExecute| + SegmentDescriptorSystem) +} + +func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) { + d.set(base, limit, dpl, + SegmentDescriptorWrite| + SegmentDescriptorSystem) +} + +// setHi is only used for the TSS segment, which is magically 64-bits. +func (d *SegmentDescriptor) setHi(base uint32) { + d.bits[0] = base + d.bits[1] = 0 +} + +// Gate64 is a 64-bit task, trap, or interrupt gate. +type Gate64 struct { + bits [4]uint32 +} + +// idt64 is a 64-bit interrupt descriptor table. +type idt64 [_NR_INTERRUPTS]Gate64 + +func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) { + g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF + g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7 + g.bits[2] = uint32(rip >> 32) +} + +func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) { + g.setInterrupt(cs, rip, dpl, ist) + g.bits[1] |= 1 << 8 +} + +// TaskState64 is a 64-bit task state structure. +type TaskState64 struct { + _ uint32 + rsp0Lo, rsp0Hi uint32 + rsp1Lo, rsp1Hi uint32 + rsp2Lo, rsp2Hi uint32 + _ [2]uint32 + ist1Lo, ist1Hi uint32 + ist2Lo, ist2Hi uint32 + ist3Lo, ist3Hi uint32 + ist4Lo, ist4Hi uint32 + ist5Lo, ist5Hi uint32 + ist6Lo, ist6Hi uint32 + ist7Lo, ist7Hi uint32 + _ [2]uint32 + _ uint16 + ioPerm uint16 +} diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_impl_amd64.s index 8cb8c4996..d082d06a9 100644..100755 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/sentry/platform/ring0/entry_impl_amd64.s @@ -1,3 +1,67 @@ +// build +amd64 + +// Automatically generated, do not edit. + +// CPU offsets. +#define CPU_SELF 0x00 +#define CPU_REGISTERS 0x288 +#define CPU_STACK_TOP 0x110 +#define CPU_ERROR_CODE 0x110 +#define CPU_ERROR_TYPE 0x118 + +// Bits. +#define _RFLAGS_IF 0x200 +#define _KERNEL_FLAGS 0x02 + +// Vectors. +#define DivideByZero 0x00 +#define Debug 0x01 +#define NMI 0x02 +#define Breakpoint 0x03 +#define Overflow 0x04 +#define BoundRangeExceeded 0x05 +#define InvalidOpcode 0x06 +#define DeviceNotAvailable 0x07 +#define DoubleFault 0x08 +#define CoprocessorSegmentOverrun 0x09 +#define InvalidTSS 0x0a +#define SegmentNotPresent 0x0b +#define StackSegmentFault 0x0c +#define GeneralProtectionFault 0x0d +#define PageFault 0x0e +#define X87FloatingPointException 0x10 +#define AlignmentCheck 0x11 +#define MachineCheck 0x12 +#define SIMDFloatingPointException 0x13 +#define VirtualizationException 0x14 +#define SecurityException 0x1e +#define SyscallInt80 0x80 +#define Syscall 0x81 + +// Ptrace registers. +#define PTRACE_R15 0x00 +#define PTRACE_R14 0x08 +#define PTRACE_R13 0x10 +#define PTRACE_R12 0x18 +#define PTRACE_RBP 0x20 +#define PTRACE_RBX 0x28 +#define PTRACE_R11 0x30 +#define PTRACE_R10 0x38 +#define PTRACE_R9 0x40 +#define PTRACE_R8 0x48 +#define PTRACE_RAX 0x50 +#define PTRACE_RCX 0x58 +#define PTRACE_RDX 0x60 +#define PTRACE_RSI 0x68 +#define PTRACE_RDI 0x70 +#define PTRACE_ORIGRAX 0x78 +#define PTRACE_RIP 0x80 +#define PTRACE_CS 0x88 +#define PTRACE_FLAGS 0x90 +#define PTRACE_RSP 0x98 +#define PTRACE_SS 0xa0 +#define PTRACE_FS 0xa8 +#define PTRACE_GS 0xb0 // Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD deleted file mode 100644 index d7029d5a9..000000000 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") - -go_template_instance( - name = "defs_impl", - out = "defs_impl.go", - package = "main", - template = "//pkg/sentry/platform/ring0:defs", -) - -go_binary( - name = "gen_offsets", - srcs = [ - "defs_impl.go", - "main.go", - ], - visibility = ["//pkg/sentry/platform/ring0:__pkg__"], - deps = [ - "//pkg/cpuid", - "//pkg/sentry/platform/ring0/pagetables", - "//pkg/sentry/usermem", - ], -) diff --git a/pkg/sentry/platform/ring0/gen_offsets/main.go b/pkg/sentry/platform/ring0/gen_offsets/main.go deleted file mode 100644 index a4927da2f..000000000 --- a/pkg/sentry/platform/ring0/gen_offsets/main.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary gen_offsets is a helper for generating offset headers. -package main - -import ( - "os" -) - -func main() { - Emit(os.Stdout) -} diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go deleted file mode 100644 index 85cc3fdad..000000000 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package ring0 - -import ( - "fmt" - "io" - "reflect" - "syscall" -) - -// Emit prints architecture-specific offsets. -func Emit(w io.Writer) { - fmt.Fprintf(w, "// Automatically generated, do not edit.\n") - - c := &CPU{} - fmt.Fprintf(w, "\n// CPU offsets.\n") - fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) - fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) - - fmt.Fprintf(w, "\n// Bits.\n") - fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) - fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) - - fmt.Fprintf(w, "\n// Vectors.\n") - fmt.Fprintf(w, "#define DivideByZero 0x%02x\n", DivideByZero) - fmt.Fprintf(w, "#define Debug 0x%02x\n", Debug) - fmt.Fprintf(w, "#define NMI 0x%02x\n", NMI) - fmt.Fprintf(w, "#define Breakpoint 0x%02x\n", Breakpoint) - fmt.Fprintf(w, "#define Overflow 0x%02x\n", Overflow) - fmt.Fprintf(w, "#define BoundRangeExceeded 0x%02x\n", BoundRangeExceeded) - fmt.Fprintf(w, "#define InvalidOpcode 0x%02x\n", InvalidOpcode) - fmt.Fprintf(w, "#define DeviceNotAvailable 0x%02x\n", DeviceNotAvailable) - fmt.Fprintf(w, "#define DoubleFault 0x%02x\n", DoubleFault) - fmt.Fprintf(w, "#define CoprocessorSegmentOverrun 0x%02x\n", CoprocessorSegmentOverrun) - fmt.Fprintf(w, "#define InvalidTSS 0x%02x\n", InvalidTSS) - fmt.Fprintf(w, "#define SegmentNotPresent 0x%02x\n", SegmentNotPresent) - fmt.Fprintf(w, "#define StackSegmentFault 0x%02x\n", StackSegmentFault) - fmt.Fprintf(w, "#define GeneralProtectionFault 0x%02x\n", GeneralProtectionFault) - fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) - fmt.Fprintf(w, "#define X87FloatingPointException 0x%02x\n", X87FloatingPointException) - fmt.Fprintf(w, "#define AlignmentCheck 0x%02x\n", AlignmentCheck) - fmt.Fprintf(w, "#define MachineCheck 0x%02x\n", MachineCheck) - fmt.Fprintf(w, "#define SIMDFloatingPointException 0x%02x\n", SIMDFloatingPointException) - fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException) - fmt.Fprintf(w, "#define SecurityException 0x%02x\n", SecurityException) - fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80) - fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) - - p := &syscall.PtraceRegs{} - fmt.Fprintf(w, "\n// Ptrace registers.\n") - fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.R13).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.R12).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RBP 0x%02x\n", reflect.ValueOf(&p.Rbp).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RBX 0x%02x\n", reflect.ValueOf(&p.Rbx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.R11).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.R10).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.R9).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.R8).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RAX 0x%02x\n", reflect.ValueOf(&p.Rax).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RCX 0x%02x\n", reflect.ValueOf(&p.Rcx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RDX 0x%02x\n", reflect.ValueOf(&p.Rdx).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RSI 0x%02x\n", reflect.ValueOf(&p.Rsi).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RDI 0x%02x\n", reflect.ValueOf(&p.Rdi).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_ORIGRAX 0x%02x\n", reflect.ValueOf(&p.Orig_rax).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RIP 0x%02x\n", reflect.ValueOf(&p.Rip).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_CS 0x%02x\n", reflect.ValueOf(&p.Cs).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_FLAGS 0x%02x\n", reflect.ValueOf(&p.Eflags).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_RSP 0x%02x\n", reflect.ValueOf(&p.Rsp).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_SS 0x%02x\n", reflect.ValueOf(&p.Ss).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_FS 0x%02x\n", reflect.ValueOf(&p.Fs_base).Pointer()-reflect.ValueOf(p).Pointer()) - fmt.Fprintf(w, "#define PTRACE_GS 0x%02x\n", reflect.ValueOf(&p.Gs_base).Pointer()-reflect.ValueOf(p).Pointer()) -} diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD deleted file mode 100644 index 3b95af617..000000000 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ /dev/null @@ -1,105 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - -go_template( - name = "generic_walker", - srcs = [ - "walker_amd64.go", - ], - opt_types = [ - "Visitor", - ], - visibility = [":__pkg__"], -) - -go_template_instance( - name = "walker_map", - out = "walker_map.go", - package = "pagetables", - prefix = "map", - template = ":generic_walker", - types = { - "Visitor": "mapVisitor", - }, -) - -go_template_instance( - name = "walker_unmap", - out = "walker_unmap.go", - package = "pagetables", - prefix = "unmap", - template = ":generic_walker", - types = { - "Visitor": "unmapVisitor", - }, -) - -go_template_instance( - name = "walker_lookup", - out = "walker_lookup.go", - package = "pagetables", - prefix = "lookup", - template = ":generic_walker", - types = { - "Visitor": "lookupVisitor", - }, -) - -go_template_instance( - name = "walker_empty", - out = "walker_empty.go", - package = "pagetables", - prefix = "empty", - template = ":generic_walker", - types = { - "Visitor": "emptyVisitor", - }, -) - -go_template_instance( - name = "walker_check", - out = "walker_check.go", - package = "pagetables", - prefix = "check", - template = ":generic_walker", - types = { - "Visitor": "checkVisitor", - }, -) - -go_library( - name = "pagetables", - srcs = [ - "allocator.go", - "allocator_unsafe.go", - "pagetables.go", - "pagetables_amd64.go", - "pagetables_x86.go", - "pcids_x86.go", - "walker_empty.go", - "walker_lookup.go", - "walker_map.go", - "walker_unmap.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables", - visibility = [ - "//pkg/sentry/platform/kvm:__subpackages__", - "//pkg/sentry/platform/ring0:__subpackages__", - ], - deps = ["//pkg/sentry/usermem"], -) - -go_test( - name = "pagetables_test", - size = "small", - srcs = [ - "pagetables_amd64_test.go", - "pagetables_test.go", - "walker_check.go", - ], - embed = [":pagetables"], - deps = ["//pkg/sentry/usermem"], -) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go deleted file mode 100644 index 35e917526..000000000 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - -package pagetables - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -func Test2MAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read}}, - }) -} - -func Test1GAnd4K(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a small page and a super page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestSplit1GPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a super page and knock out the middle. - pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*42) - pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestSplit2MPage(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map a huge page and knock out the middle. - pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*42) - pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize)) - - checkMappings(t, pt, []mapping{ - {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read}}, - }) -} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go b/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go new file mode 100755 index 000000000..ac1ccf3d3 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package pagetables + diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go deleted file mode 100644 index 6e95ad2b9..000000000 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagetables - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/sentry/usermem" -) - -type mapping struct { - start uintptr - length uintptr - addr uintptr - opts MapOpts -} - -type checkVisitor struct { - expected []mapping // Input. - current int // Temporary. - found []mapping // Output. - failed string // Output. -} - -func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { - v.found = append(v.found, mapping{ - start: start, - length: align + 1, - addr: pte.Address(), - opts: pte.Opts(), - }) - if v.failed != "" { - // Don't keep looking for errors. - return - } - - if v.current >= len(v.expected) { - v.failed = "more mappings than expected" - } else if v.expected[v.current].start != start { - v.failed = "start didn't match expected" - } else if v.expected[v.current].length != (align + 1) { - v.failed = "end didn't match expected" - } else if v.expected[v.current].addr != pte.Address() { - v.failed = "address didn't match expected" - } else if v.expected[v.current].opts != pte.Opts() { - v.failed = "opts didn't match" - } - v.current++ -} - -func (*checkVisitor) requiresAlloc() bool { return false } - -func (*checkVisitor) requiresSplit() bool { return false } - -func checkMappings(t *testing.T, pt *PageTables, m []mapping) { - // Iterate over all the mappings. - w := checkWalker{ - pageTables: pt, - visitor: checkVisitor{ - expected: m, - }, - } - w.iterateRange(0, ^uintptr(0)) - - // Were we expected additional mappings? - if w.visitor.failed == "" && w.visitor.current != len(w.visitor.expected) { - w.visitor.failed = "insufficient mappings found" - } - - // Emit a meaningful error message on failure. - if w.visitor.failed != "" { - t.Errorf("%s; got %#v, wanted %#v", w.visitor.failed, w.visitor.found, w.visitor.expected) - } -} - -func TestUnmap(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map and unmap one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Unmap(0x400000, pteSize) - - checkMappings(t, pt, nil) -} - -func TestReadOnly(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestReadWrite(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map one entry. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - }) -} - -func TestSerialEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map two sequential entries. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x401000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.ReadWrite}}, - }) -} - -func TestSpanningEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Span a pgd with two pages. - pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42) - - checkMappings(t, pt, []mapping{ - {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}}, - {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: usermem.Read}}, - }) -} - -func TestSparseEntries(t *testing.T) { - pt := New(NewRuntimeAllocator()) - - // Map two entries in different pgds. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42) - pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*47) - - checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}}, - {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.Read}}, - }) -} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go b/pkg/sentry/platform/ring0/pagetables/walker_empty.go index 8f9dacd93..417784e17 100644..100755 --- a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/walker_empty.go @@ -1,42 +1,12 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build amd64 - package pagetables -// Visitor is a generic type. -type Visitor interface { - // visit is called on each PTE. - visit(start uintptr, pte *PTE, align uintptr) - - // requiresAlloc indicates that new entries should be allocated within - // the walked range. - requiresAlloc() bool - - // requiresSplit indicates that entries in the given range should be - // split if they are huge or jumbo pages. - requiresSplit() bool -} - // Walker walks page tables. -type Walker struct { +type emptyWalker struct { // pageTables are the tables to walk. pageTables *PageTables // Visitor is the set of arguments. - visitor Visitor + visitor emptyVisitor } // iterateRange iterates over all appropriate levels of page tables for the given range. @@ -63,7 +33,7 @@ type Walker struct { // non-canonical ranges. If they do, a panic will result. // //go:nosplit -func (w *Walker) iterateRange(start, end uintptr) { +func (w *emptyWalker) iterateRange(start, end uintptr) { if start%pteSize != 0 { panic("unaligned start") } @@ -104,7 +74,7 @@ func (w *Walker) iterateRange(start, end uintptr) { // next returns the next address quantized by the given size. // //go:nosplit -func next(start uintptr, size uintptr) uintptr { +func emptynext(start uintptr, size uintptr) uintptr { start &= ^(size - 1) start += size return start @@ -113,7 +83,7 @@ func next(start uintptr, size uintptr) uintptr { // iterateRangeCanonical walks a canonical range. // //go:nosplit -func (w *Walker) iterateRangeCanonical(start, end uintptr) { +func (w *emptyWalker) iterateRangeCanonical(start, end uintptr) { for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { var ( pgdEntry = &w.pageTables.root[pgdIndex] @@ -121,19 +91,17 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { ) if !pgdEntry.Valid() { if !w.visitor.requiresAlloc() { - // Skip over this entry. - start = next(start, pgdSize) + + start = emptynext(start, pgdSize) continue } - // Allocate a new pgd. pudEntries = w.pageTables.Allocator.NewPTEs() pgdEntry.setPageTable(w.pageTables, pudEntries) } else { pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) } - // Map the next level. clearPUDEntries := uint16(0) for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { @@ -143,33 +111,28 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { ) if !pudEntry.Valid() { if !w.visitor.requiresAlloc() { - // Skip over this entry. + clearPUDEntries++ - start = next(start, pudSize) + start = emptynext(start, pudSize) continue } - // This level has 1-GB super pages. Is this - // entire region at least as large as a single - // PUD entry? If so, we can skip allocating a - // new page for the pmd. if start&(pudSize-1) == 0 && end-start >= pudSize { pudEntry.SetSuper() w.visitor.visit(uintptr(start), pudEntry, pudSize-1) if pudEntry.Valid() { - start = next(start, pudSize) + start = emptynext(start, pudSize) continue } } - // Allocate a new pud. pmdEntries = w.pageTables.Allocator.NewPTEs() pudEntry.setPageTable(w.pageTables, pmdEntries) } else if pudEntry.IsSuper() { - // Does this page need to be split? - if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) { - // Install the relevant entries. + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < emptynext(start, pudSize)) { + pmdEntries = w.pageTables.Allocator.NewPTEs() for index := uint16(0); index < entriesPerPage; index++ { pmdEntries[index].SetSuper() @@ -179,23 +142,20 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } pudEntry.setPageTable(w.pageTables, pmdEntries) } else { - // A super page to be checked directly. + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) - // Might have been cleared. if !pudEntry.Valid() { clearPUDEntries++ } - // Note that the super page was changed. - start = next(start, pudSize) + start = emptynext(start, pudSize) continue } } else { pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) } - // Map the next level, since this is valid. clearPMDEntries := uint16(0) for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { @@ -205,32 +165,28 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { ) if !pmdEntry.Valid() { if !w.visitor.requiresAlloc() { - // Skip over this entry. + clearPMDEntries++ - start = next(start, pmdSize) + start = emptynext(start, pmdSize) continue } - // This level has 2-MB huge pages. If this - // region is contined in a single PMD entry? - // As above, we can skip allocating a new page. if start&(pmdSize-1) == 0 && end-start >= pmdSize { pmdEntry.SetSuper() w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) if pmdEntry.Valid() { - start = next(start, pmdSize) + start = emptynext(start, pmdSize) continue } } - // Allocate a new pmd. pteEntries = w.pageTables.Allocator.NewPTEs() pmdEntry.setPageTable(w.pageTables, pteEntries) } else if pmdEntry.IsSuper() { - // Does this page need to be split? - if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) { - // Install the relevant entries. + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < emptynext(start, pmdSize)) { + pteEntries = w.pageTables.Allocator.NewPTEs() for index := uint16(0); index < entriesPerPage; index++ { pteEntries[index].Set( @@ -239,23 +195,20 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } pmdEntry.setPageTable(w.pageTables, pteEntries) } else { - // A huge page to be checked directly. + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) - // Might have been cleared. if !pmdEntry.Valid() { clearPMDEntries++ } - // Note that the huge page was changed. - start = next(start, pmdSize) + start = emptynext(start, pmdSize) continue } } else { pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) } - // Map the next level, since this is valid. clearPTEEntries := uint16(0) for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { @@ -268,7 +221,6 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { continue } - // At this point, we are guaranteed that start%pteSize == 0. w.visitor.visit(uintptr(start), pteEntry, pteSize-1) if !pteEntry.Valid() { if w.visitor.requiresAlloc() { @@ -277,12 +229,10 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { clearPTEEntries++ } - // Note that the pte was changed. start += pteSize continue } - // Check if we no longer need this page. if clearPTEEntries == entriesPerPage { pmdEntry.Clear() w.pageTables.Allocator.FreePTEs(pteEntries) @@ -290,7 +240,6 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } } - // Check if we no longer need this page. if clearPMDEntries == entriesPerPage { pudEntry.Clear() w.pageTables.Allocator.FreePTEs(pmdEntries) @@ -298,7 +247,6 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } } - // Check if we no longer need this page. if clearPUDEntries == entriesPerPage { pgdEntry.Clear() w.pageTables.Allocator.FreePTEs(pudEntries) diff --git a/pkg/sentry/platform/ring0/pagetables/walker_lookup.go b/pkg/sentry/platform/ring0/pagetables/walker_lookup.go new file mode 100755 index 000000000..906c9c50f --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_lookup.go @@ -0,0 +1,255 @@ +package pagetables + +// Walker walks page tables. +type lookupWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor lookupVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *lookupWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func lookupnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *lookupWalker) iterateRangeCanonical(start, end uintptr) { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = lookupnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = lookupnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = lookupnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < lookupnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = lookupnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = lookupnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = lookupnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < lookupnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = lookupnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_map.go b/pkg/sentry/platform/ring0/pagetables/walker_map.go new file mode 100755 index 000000000..61ee3c825 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_map.go @@ -0,0 +1,255 @@ +package pagetables + +// Walker walks page tables. +type mapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor mapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *mapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func mapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *mapWalker) iterateRangeCanonical(start, end uintptr) { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = mapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = mapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = mapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < mapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = mapnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = mapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = mapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < mapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = mapnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_unmap.go b/pkg/sentry/platform/ring0/pagetables/walker_unmap.go new file mode 100755 index 000000000..be2aa0ce4 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_unmap.go @@ -0,0 +1,255 @@ +package pagetables + +// Walker walks page tables. +type unmapWalker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor unmapVisitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *unmapWalker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func unmapnext(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *unmapWalker) iterateRangeCanonical(start, end uintptr) { + for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &w.pageTables.root[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + start = unmapnext(start, pgdSize) + continue + } + + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPUDEntries++ + start = unmapnext(start, pudSize) + continue + } + + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSuper() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = unmapnext(start, pudSize) + continue + } + } + + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < unmapnext(start, pudSize)) { + + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSuper() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + if !pudEntry.Valid() { + clearPUDEntries++ + } + + start = unmapnext(start, pudSize) + continue + } + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + + clearPMDEntries++ + start = unmapnext(start, pmdSize) + continue + } + + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSuper() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = unmapnext(start, pmdSize) + continue + } + } + + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSuper() { + + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < unmapnext(start, pmdSize)) { + + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + start = unmapnext(start, pmdSize) + continue + } + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + start += pteSize + continue + } + + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/ring0/ring0_state_autogen.go b/pkg/sentry/platform/ring0/ring0_state_autogen.go new file mode 100755 index 000000000..462f9a446 --- /dev/null +++ b/pkg/sentry/platform/ring0/ring0_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package ring0 + diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go deleted file mode 100644 index 5f80d64e8..000000000 --- a/pkg/sentry/platform/ring0/x86.go +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build i386 amd64 - -package ring0 - -import ( - "gvisor.dev/gvisor/pkg/cpuid" -) - -// Useful bits. -const ( - _CR0_PE = 1 << 0 - _CR0_ET = 1 << 4 - _CR0_AM = 1 << 18 - _CR0_PG = 1 << 31 - - _CR4_PSE = 1 << 4 - _CR4_PAE = 1 << 5 - _CR4_PGE = 1 << 7 - _CR4_OSFXSR = 1 << 9 - _CR4_OSXMMEXCPT = 1 << 10 - _CR4_FSGSBASE = 1 << 16 - _CR4_PCIDE = 1 << 17 - _CR4_OSXSAVE = 1 << 18 - _CR4_SMEP = 1 << 20 - - _RFLAGS_AC = 1 << 18 - _RFLAGS_NT = 1 << 14 - _RFLAGS_IOPL = 3 << 12 - _RFLAGS_DF = 1 << 10 - _RFLAGS_IF = 1 << 9 - _RFLAGS_STEP = 1 << 8 - _RFLAGS_RESERVED = 1 << 1 - - _EFER_SCE = 0x001 - _EFER_LME = 0x100 - _EFER_LMA = 0x400 - _EFER_NX = 0x800 - - _MSR_STAR = 0xc0000081 - _MSR_LSTAR = 0xc0000082 - _MSR_CSTAR = 0xc0000083 - _MSR_SYSCALL_MASK = 0xc0000084 - _MSR_PLATFORM_INFO = 0xce - _MSR_MISC_FEATURES = 0x140 - - _PLATFORM_INFO_CPUID_FAULT = 1 << 31 - - _MISC_FEATURE_CPUID_TRAP = 0x1 -) - -const ( - // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = _RFLAGS_RESERVED - - // UserFlagsSet are always set in userspace. - UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF - - // KernelFlagsClear should always be clear in the kernel. - KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT - - // UserFlagsClear are always cleared in userspace. - UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL -) - -// Vector is an exception vector. -type Vector uintptr - -// Exception vectors. -const ( - DivideByZero Vector = iota - Debug - NMI - Breakpoint - Overflow - BoundRangeExceeded - InvalidOpcode - DeviceNotAvailable - DoubleFault - CoprocessorSegmentOverrun - InvalidTSS - SegmentNotPresent - StackSegmentFault - GeneralProtectionFault - PageFault - _ - X87FloatingPointException - AlignmentCheck - MachineCheck - SIMDFloatingPointException - VirtualizationException - SecurityException = 0x1e - SyscallInt80 = 0x80 - _NR_INTERRUPTS = SyscallInt80 + 1 -) - -// System call vectors. -const ( - Syscall Vector = _NR_INTERRUPTS -) - -// VirtualAddressBits returns the number bits available for virtual addresses. -// -// Note that sign-extension semantics apply to the highest order bit. -// -// FIXME(b/69382326): This should use the cpuid passed to Init. -func VirtualAddressBits() uint32 { - ax, _, _, _ := cpuid.HostID(0x80000008, 0) - return (ax >> 8) & 0xff -} - -// PhysicalAddressBits returns the number of bits available for physical addresses. -// -// FIXME(b/69382326): This should use the cpuid passed to Init. -func PhysicalAddressBits() uint32 { - ax, _, _, _ := cpuid.HostID(0x80000008, 0) - return ax & 0xff -} - -// Selector is a segment Selector. -type Selector uint16 - -// SegmentDescriptor is a segment descriptor. -type SegmentDescriptor struct { - bits [2]uint32 -} - -// descriptorTable is a collection of descriptors. -type descriptorTable [32]SegmentDescriptor - -// SegmentDescriptorFlags are typed flags within a descriptor. -type SegmentDescriptorFlags uint32 - -// SegmentDescriptorFlag declarations. -const ( - SegmentDescriptorAccess SegmentDescriptorFlags = 1 << 8 // Access bit (always set). - SegmentDescriptorWrite = 1 << 9 // Write permission. - SegmentDescriptorExpandDown = 1 << 10 // Grows down, not used. - SegmentDescriptorExecute = 1 << 11 // Execute permission. - SegmentDescriptorSystem = 1 << 12 // Zero => system, 1 => user code/data. - SegmentDescriptorPresent = 1 << 15 // Present. - SegmentDescriptorAVL = 1 << 20 // Available. - SegmentDescriptorLong = 1 << 21 // Long mode. - SegmentDescriptorDB = 1 << 22 // 16 or 32-bit. - SegmentDescriptorG = 1 << 23 // Granularity: page or byte. -) - -// Base returns the descriptor's base linear address. -func (d *SegmentDescriptor) Base() uint32 { - return d.bits[1]&0xFF000000 | (d.bits[1]&0x000000FF)<<16 | d.bits[0]>>16 -} - -// Limit returns the descriptor size. -func (d *SegmentDescriptor) Limit() uint32 { - l := d.bits[0]&0xFFFF | d.bits[1]&0xF0000 - if d.bits[1]&uint32(SegmentDescriptorG) != 0 { - l <<= 12 - l |= 0xFFF - } - return l -} - -// Flags returns descriptor flags. -func (d *SegmentDescriptor) Flags() SegmentDescriptorFlags { - return SegmentDescriptorFlags(d.bits[1] & 0x00F09F00) -} - -// DPL returns the descriptor privilege level. -func (d *SegmentDescriptor) DPL() int { - return int((d.bits[1] >> 13) & 3) -} - -func (d *SegmentDescriptor) setNull() { - d.bits[0] = 0 - d.bits[1] = 0 -} - -func (d *SegmentDescriptor) set(base, limit uint32, dpl int, flags SegmentDescriptorFlags) { - flags |= SegmentDescriptorPresent - if limit>>12 != 0 { - limit >>= 12 - flags |= SegmentDescriptorG - } - d.bits[0] = base<<16 | limit&0xFFFF - d.bits[1] = base&0xFF000000 | (base>>16)&0xFF | limit&0x000F0000 | uint32(flags) | uint32(dpl)<<13 -} - -func (d *SegmentDescriptor) setCode32(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorDB| - SegmentDescriptorExecute| - SegmentDescriptorSystem) -} - -func (d *SegmentDescriptor) setCode64(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorG| - SegmentDescriptorLong| - SegmentDescriptorExecute| - SegmentDescriptorSystem) -} - -func (d *SegmentDescriptor) setData(base, limit uint32, dpl int) { - d.set(base, limit, dpl, - SegmentDescriptorWrite| - SegmentDescriptorSystem) -} - -// setHi is only used for the TSS segment, which is magically 64-bits. -func (d *SegmentDescriptor) setHi(base uint32) { - d.bits[0] = base - d.bits[1] = 0 -} - -// Gate64 is a 64-bit task, trap, or interrupt gate. -type Gate64 struct { - bits [4]uint32 -} - -// idt64 is a 64-bit interrupt descriptor table. -type idt64 [_NR_INTERRUPTS]Gate64 - -func (g *Gate64) setInterrupt(cs Selector, rip uint64, dpl int, ist int) { - g.bits[0] = uint32(cs)<<16 | uint32(rip)&0xFFFF - g.bits[1] = uint32(rip)&0xFFFF0000 | SegmentDescriptorPresent | uint32(dpl)<<13 | 14<<8 | uint32(ist)&0x7 - g.bits[2] = uint32(rip >> 32) -} - -func (g *Gate64) setTrap(cs Selector, rip uint64, dpl int, ist int) { - g.setInterrupt(cs, rip, dpl, ist) - g.bits[1] |= 1 << 8 -} - -// TaskState64 is a 64-bit task state structure. -type TaskState64 struct { - _ uint32 - rsp0Lo, rsp0Hi uint32 - rsp1Lo, rsp1Hi uint32 - rsp2Lo, rsp2Hi uint32 - _ [2]uint32 - ist1Lo, ist1Hi uint32 - ist2Lo, ist2Hi uint32 - ist3Lo, ist3Hi uint32 - ist4Lo, ist4Hi uint32 - ist5Lo, ist5Hi uint32 - ist6Lo, ist6Hi uint32 - ist7Lo, ist7Hi uint32 - _ [2]uint32 - _ uint16 - ioPerm uint16 -} diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD deleted file mode 100644 index 924d8a6d6..000000000 --- a/pkg/sentry/platform/safecopy/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safecopy", - srcs = [ - "atomic_amd64.s", - "atomic_arm64.s", - "memclr_amd64.s", - "memclr_arm64.s", - "memcpy_amd64.s", - "memcpy_arm64.s", - "safecopy.go", - "safecopy_unsafe.go", - "sighandler_amd64.s", - "sighandler_arm64.s", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy", - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/syserror"], -) - -go_test( - name = "safecopy_test", - srcs = [ - "safecopy_test.go", - ], - embed = [":safecopy"], -) diff --git a/pkg/sentry/platform/safecopy/LICENSE b/pkg/sentry/platform/safecopy/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/sentry/platform/safecopy/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sentry/platform/safecopy/safecopy_state_autogen.go b/pkg/sentry/platform/safecopy/safecopy_state_autogen.go new file mode 100755 index 000000000..58fd8fbd0 --- /dev/null +++ b/pkg/sentry/platform/safecopy/safecopy_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package safecopy + diff --git a/pkg/sentry/platform/safecopy/safecopy_test.go b/pkg/sentry/platform/safecopy/safecopy_test.go deleted file mode 100644 index 5818f7f9b..000000000 --- a/pkg/sentry/platform/safecopy/safecopy_test.go +++ /dev/null @@ -1,617 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safecopy - -import ( - "bytes" - "fmt" - "io/ioutil" - "math/rand" - "os" - "runtime/debug" - "syscall" - "testing" - "unsafe" -) - -// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular -// dependency. -const pageSize = 4096 - -func initRandom(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func randBuf(size int) []byte { - b := make([]byte, size) - initRandom(b) - return b -} - -func TestCopyInSuccess(t *testing.T) { - // Test that CopyIn does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyIn(b, unsafe.Pointer(&a[0])) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopyOutSuccess(t *testing.T) { - // Test that CopyOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := CopyOut(unsafe.Pointer(&b[0]), a) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestCopySuccess(t *testing.T) { - // Test that Copy does not return an error when all pages are accessible. - const bufLen = 8192 - a := randBuf(bufLen) - b := make([]byte, bufLen) - - n, err := Copy(unsafe.Pointer(&b[0]), unsafe.Pointer(&a[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestZeroOutSuccess(t *testing.T) { - // Test that ZeroOut does not return an error when all pages are - // accessible. - const bufLen = 8192 - a := make([]byte, bufLen) - b := randBuf(bufLen) - - n, err := ZeroOut(unsafe.Pointer(&b[0]), bufLen) - if n != bufLen { - t.Errorf("Unexpected copy length, got %v, want %v", n, bufLen) - } - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if !bytes.Equal(a, b) { - t.Errorf("Buffers are not equal when they should be: %v %v", a, b) - } -} - -func TestSwapUint32Success(t *testing.T) { - // Test that SwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := SwapUint32(unsafe.Pointer(&val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestSwapUint32AlignmentError(t *testing.T) { - // Test that SwapUint32 returns an AlignmentError when passed an unaligned - // address. - data := new(struct{ val uint64 }) - addr := uintptr(unsafe.Pointer(&data.val)) + 1 - want := AlignmentError{Addr: addr, Alignment: 4} - if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestSwapUint64Success(t *testing.T) { - // Test that SwapUint64 does not return an error when the page is - // accessible. - before := uint64(rand.Int63()) - after := uint64(rand.Int63()) - // "The first word in ... an allocated struct or slice can be relied upon - // to be 64-bit aligned." - sync/atomic docs - data := new(struct{ val uint64 }) - data.val = before - - old, err := SwapUint64(unsafe.Pointer(&data.val), after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if data.val != after { - t.Errorf("Unexpected new value: got %v, want %v", data.val, after) - } -} - -func TestSwapUint64AlignmentError(t *testing.T) { - // Test that SwapUint64 returns an AlignmentError when passed an unaligned - // address. - data := new(struct{ val1, val2 uint64 }) - addr := uintptr(unsafe.Pointer(&data.val1)) + 1 - want := AlignmentError{Addr: addr, Alignment: 8} - if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -func TestCompareAndSwapUint32Success(t *testing.T) { - // Test that CompareAndSwapUint32 does not return an error when the page is - // accessible. - before := uint32(rand.Int31()) - after := uint32(rand.Int31()) - val := before - - old, err := CompareAndSwapUint32(unsafe.Pointer(&val), before, after) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if old != before { - t.Errorf("Unexpected old value: got %v, want %v", old, before) - } - if val != after { - t.Errorf("Unexpected new value: got %v, want %v", val, after) - } -} - -func TestCompareAndSwapUint32AlignmentError(t *testing.T) { - // Test that CompareAndSwapUint32 returns an AlignmentError when passed an - // unaligned address. - data := new(struct{ val uint64 }) - addr := uintptr(unsafe.Pointer(&data.val)) + 1 - want := AlignmentError{Addr: addr, Alignment: 4} - if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } -} - -// withSegvErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGSEGV when accessed. -func withSegvErrorTestMapping(t *testing.T, fn func(m []byte)) { - mapping, err := syscall.Mmap(-1, 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANONYMOUS|syscall.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer syscall.Munmap(mapping) - if err := syscall.Mprotect(mapping[pageSize:], syscall.PROT_NONE); err != nil { - t.Fatalf("Mprotect failed: %v", err) - } - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -// withBusErrorTestMapping calls fn with a two-page mapping. The first page -// contains random data, and the second page generates SIGBUS when accessed. -func withBusErrorTestMapping(t *testing.T, fn func(m []byte)) { - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - if err := f.Truncate(pageSize); err != nil { - t.Fatalf("Truncate failed: %v", err) - } - mapping, err := syscall.Mmap(int(f.Fd()), 0, 2*pageSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - } - defer syscall.Munmap(mapping) - initRandom(mapping[:pageSize]) - - fn(mapping) -} - -func TestCopyInSegvError(t *testing.T) { - // Test that CopyIn returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyInBusError(t *testing.T) { - // Test that CopyIn returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := CopyIn(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutSegvError(t *testing.T) { - // Test that CopyOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyOutBusError(t *testing.T) { - // Test that CopyOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := CopyOut(dst, src) - if n != bytesBeforeFault { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying from a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopySourceBusError(t *testing.T) { - // Test that Copy returns a BusError when copying from a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - dst := randBuf(pageSize) - n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := dst[:bytesBeforeFault], mapping[pageSize-bytesBeforeFault:pageSize]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationSegvError(t *testing.T) { - // Test that Copy returns a SegvError when copying to a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestCopyDestinationBusError(t *testing.T) { - // Test that Copy returns a BusError when copying to a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - src := randBuf(pageSize) - n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected copy length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], src[:bytesBeforeFault]; !bytes.Equal(got, want) { - t.Errorf("Buffers are not equal when they should be: %v %v", got, want) - } - }) - }) - } -} - -func TestZeroOutSegvError(t *testing.T) { - // Test that ZeroOut returns a SegvError when reaching a page that signals - // SIGSEGV. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) { - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestZeroOutBusError(t *testing.T) { - // Test that ZeroOut returns a BusError when reaching a page that signals - // SIGBUS. - for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ { - t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) { - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault)) - n, err := ZeroOut(dst, pageSize) - if n != uintptr(bytesBeforeFault) { - t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault) - } - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - if got, want := mapping[pageSize-bytesBeforeFault:pageSize], make([]byte, bytesBeforeFault); !bytes.Equal(got, want) { - t.Errorf("Non-zero bytes in written part of mapping: %v", got) - } - }) - }) - } -} - -func TestSwapUint32SegvError(t *testing.T) { - // Test that SwapUint32 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint32BusError(t *testing.T) { - // Test that SwapUint32 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint32(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64SegvError(t *testing.T) { - // Test that SwapUint64 returns a SegvError when reaching a page that - // signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestSwapUint64BusError(t *testing.T) { - // Test that SwapUint64 returns a BusError when reaching a page that - // signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := SwapUint64(unsafe.Pointer(secondPage), 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32SegvError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a SegvError when reaching a page - // that signals SIGSEGV. - withSegvErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (SegvError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func TestCompareAndSwapUint32BusError(t *testing.T) { - // Test that CompareAndSwapUint32 returns a BusError when reaching a page - // that signals SIGBUS. - withBusErrorTestMapping(t, func(mapping []byte) { - secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize - _, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1) - if want := (BusError{secondPage}); err != want { - t.Errorf("Unexpected error: got %v, want %v", err, want) - } - }) -} - -func testCopy(dst, src []byte) (panicked bool) { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - debug.SetPanicOnFault(true) - copy(dst, src) - return -} - -func TestSegVOnMemmove(t *testing.T) { - // Test that SIGSEGVs received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - a, err := syscall.Mmap(-1, 0, bufLen, syscall.PROT_NONE, syscall.MAP_ANON|syscall.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer syscall.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} - -func TestSigbusOnMemmove(t *testing.T) { - // Test that SIGBUS received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - os.Remove(f.Name()) - defer f.Close() - - a, err := syscall.Mmap(int(f.Fd()), 0, bufLen, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer syscall.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD deleted file mode 100644 index fd6dc8e6e..000000000 --- a/pkg/sentry/safemem/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "safemem", - srcs = [ - "block_unsafe.go", - "io.go", - "safemem.go", - "seq_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/safemem", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/platform/safecopy", - ], -) - -go_test( - name = "safemem_test", - size = "small", - srcs = [ - "io_test.go", - "seq_test.go", - ], - embed = [":safemem"], -) diff --git a/pkg/sentry/safemem/io_test.go b/pkg/sentry/safemem/io_test.go deleted file mode 100644 index 629741bee..000000000 --- a/pkg/sentry/safemem/io_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "io" - "testing" -) - -func makeBlocks(slices ...[]byte) []Block { - blocks := make([]Block, 0, len(slices)) - for _, s := range slices { - blocks = append(blocks, BlockFromSafeSlice(s)) - } - return blocks -} - -func TestFromIOReaderFullRead(t *testing.T) { - r := FromIOReader{bytes.NewBufferString("foobar")} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type eofHidingReader struct { - Reader io.Reader -} - -func (r eofHidingReader) Read(dst []byte) (int, error) { - n, err := r.Reader.Read(dst) - if err == io.EOF { - return n, nil - } - return n, err -} - -func TestFromIOReaderPartialRead(t *testing.T) { - r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the eofHidingReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -type singleByteReader struct { - Reader io.Reader -} - -func (r singleByteReader) Read(dst []byte) (int, error) { - if len(dst) == 0 { - return r.Reader.Read(dst) - } - return r.Reader.Read(dst[:1]) -} - -func TestSingleByteReader(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) - // FromIOReader should stop after the singleByteReader returns (1, nil) - // for a 3-byte read. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestReadFullToBlocks(t *testing.T) { - r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} - dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) - n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts)) - // ReadFullToBlocks should call into FromIOReader => singleByteReader - // repeatedly until dsts is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { - if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { - t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) - } - } -} - -func TestFromIOWriterFullWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&dst} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type limitedWriter struct { - Writer io.Writer - Done int - Limit int -} - -func (w *limitedWriter) Write(src []byte) (int, error) { - count := len(src) - if count > (w.Limit - w.Done) { - count = w.Limit - w.Done - } - n, err := w.Writer.Write(src[:count]) - w.Done += n - return n, err -} - -func TestFromIOWriterPartialWrite(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{&limitedWriter{&dst, 0, 4}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the limitedWriter returns (1, nil) for a - // 3-byte write. - if wantN := uint64(4); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -type singleByteWriter struct { - Writer io.Writer -} - -func (w singleByteWriter) Write(src []byte) (int, error) { - if len(src) == 0 { - return w.Writer.Write(src) - } - return w.Writer.Write(src[:1]) -} - -func TestSingleByteWriter(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) - // FromIOWriter should stop after the singleByteWriter returns (1, nil) - // for a 3-byte write. - if wantN := uint64(1); n != wantN || err != nil { - t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestWriteFullToBlocks(t *testing.T) { - srcs := makeBlocks([]byte("foo"), []byte("bar")) - var dst bytes.Buffer - w := FromIOWriter{singleByteWriter{&dst}} - n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs)) - // WriteFullToBlocks should call into FromIOWriter => singleByteWriter - // repeatedly until srcs is exhausted. - if wantN := uint64(6); n != wantN || err != nil { - t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} diff --git a/pkg/sentry/safemem/safemem_state_autogen.go b/pkg/sentry/safemem/safemem_state_autogen.go new file mode 100755 index 000000000..7264df0b1 --- /dev/null +++ b/pkg/sentry/safemem/safemem_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package safemem + diff --git a/pkg/sentry/safemem/seq_test.go b/pkg/sentry/safemem/seq_test.go deleted file mode 100644 index eba4bb535..000000000 --- a/pkg/sentry/safemem/seq_test.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package safemem - -import ( - "bytes" - "reflect" - "testing" -) - -type blockSeqTest struct { - desc string - - pieces []string - haveOffset bool - offset uint64 - haveLimit bool - limit uint64 - - want string -} - -func (t blockSeqTest) NonEmptyByteSlices() [][]byte { - // t is a value, so we can mutate it freely. - slices := make([][]byte, 0, len(t.pieces)) - for _, str := range t.pieces { - if t.haveOffset { - strOff := t.offset - if strOff > uint64(len(str)) { - strOff = uint64(len(str)) - } - str = str[strOff:] - t.offset -= strOff - } - if t.haveLimit { - strLim := t.limit - if strLim > uint64(len(str)) { - strLim = uint64(len(str)) - } - str = str[:strLim] - t.limit -= strLim - } - if len(str) != 0 { - slices = append(slices, []byte(str)) - } - } - return slices -} - -func (t blockSeqTest) BlockSeq() BlockSeq { - blocks := make([]Block, 0, len(t.pieces)) - for _, str := range t.pieces { - blocks = append(blocks, BlockFromSafeSlice([]byte(str))) - } - bs := BlockSeqFromSlice(blocks) - if t.haveOffset { - bs = bs.DropFirst64(t.offset) - } - if t.haveLimit { - bs = bs.TakeFirst64(t.limit) - } - return bs -} - -var blockSeqTests = []blockSeqTest{ - { - desc: "Empty sequence", - }, - { - desc: "Sequence of length 1", - pieces: []string{"foobar"}, - want: "foobar", - }, - { - desc: "Sequence of length 2", - pieces: []string{"foo", "bar"}, - want: "foobar", - }, - { - desc: "Empty Blocks", - pieces: []string{"", "foo", "", "", "bar", ""}, - want: "foobar", - }, - { - desc: "Sequence with non-zero offset", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - want: "obar", - }, - { - desc: "Sequence with non-maximal limit", - pieces: []string{"foo", "bar"}, - haveLimit: true, - limit: 5, - want: "fooba", - }, - { - desc: "Sequence with offset and limit", - pieces: []string{"foo", "bar"}, - haveOffset: true, - offset: 2, - haveLimit: true, - limit: 3, - want: "oba", - }, -} - -func TestBlockSeqNumBytes(t *testing.T) { - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - if got, want := test.BlockSeq().NumBytes(), uint64(len(test.want)); got != want { - t.Errorf("NumBytes: got %d, wanted %d", got, want) - } - }) - } -} - -func TestBlockSeqIterBlocks(t *testing.T) { - // Tests BlockSeq iteration using Head/Tail. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - // "Note that a non-nil empty slice and a nil slice ... are not - // deeply equal." - reflect - slices := make([][]byte, 0, 0) - for !srcs.IsEmpty() { - src := srcs.Head() - slices = append(slices, src.ToSlice()) - nextSrcs := srcs.Tail() - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-uint64(src.Len()); got != want { - t.Fatalf("%v.Tail(): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if wantSlices := test.NonEmptyByteSlices(); !reflect.DeepEqual(slices, wantSlices) { - t.Errorf("Accumulated slices: got %v, wanted %v", slices, wantSlices) - } - }) - } -} - -func TestBlockSeqIterBytes(t *testing.T) { - // Tests BlockSeq iteration using Head/DropFirst. - for _, test := range blockSeqTests { - t.Run(test.desc, func(t *testing.T) { - srcs := test.BlockSeq() - var dst bytes.Buffer - for !srcs.IsEmpty() { - src := srcs.Head() - var b [1]byte - n, err := Copy(BlockFromSafeSlice(b[:]), src) - if n != 1 || err != nil { - t.Fatalf("Copy: got (%v, %v), wanted (1, nil)", n, err) - } - dst.WriteByte(b[0]) - nextSrcs := srcs.DropFirst(1) - if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-1; got != want { - t.Fatalf("%v.DropFirst(1): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) - } - srcs = nextSrcs - } - if got := string(dst.Bytes()); got != test.want { - t.Errorf("Copied string: got %q, wanted %q", got, test.want) - } - }) - } -} - -func TestBlockSeqDropBeyondLimit(t *testing.T) { - blocks := []Block{BlockFromSafeSlice([]byte("123")), BlockFromSafeSlice([]byte("4"))} - bs := BlockSeqFromSlice(blocks) - if got, want := bs.NumBytes(), uint64(4); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.TakeFirst(1) - if got, want := bs.NumBytes(), uint64(1); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } - bs = bs.DropFirst(2) - if got, want := bs.NumBytes(), uint64(0); got != want { - t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) - } -} diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD deleted file mode 100644 index f561670c7..000000000 --- a/pkg/sentry/sighandling/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sighandling", - srcs = [ - "sighandling.go", - "sighandling_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/sighandling", - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/abi/linux"], -) diff --git a/pkg/sentry/sighandling/sighandling_state_autogen.go b/pkg/sentry/sighandling/sighandling_state_autogen.go new file mode 100755 index 000000000..dad4bdda2 --- /dev/null +++ b/pkg/sentry/sighandling/sighandling_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package sighandling + diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD deleted file mode 100644 index 7a24d4806..000000000 --- a/pkg/sentry/socket/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "socket", - srcs = ["socket.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/tcpip", - ], -) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD deleted file mode 100644 index 39de46c39..000000000 --- a/pkg/sentry/socket/control/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "control", - srcs = ["control.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/control", - imports = [ - "gvisor.dev/gvisor/pkg/sentry/fs", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/socket/control/control_state_autogen.go b/pkg/sentry/socket/control/control_state_autogen.go new file mode 100755 index 000000000..c5ecfe700 --- /dev/null +++ b/pkg/sentry/socket/control/control_state_autogen.go @@ -0,0 +1,36 @@ +// automatically generated by stateify. + +package control + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/sentry/fs" +) + +func (x *RightsFiles) save(m state.Map) { + m.SaveValue("", ([]*fs.File)(*x)) +} + +func (x *RightsFiles) load(m state.Map) { + m.LoadValue("", new([]*fs.File), func(y interface{}) { *x = (RightsFiles)(y.([]*fs.File)) }) +} + +func (x *scmCredentials) beforeSave() {} +func (x *scmCredentials) save(m state.Map) { + x.beforeSave() + m.Save("t", &x.t) + m.Save("kuid", &x.kuid) + m.Save("kgid", &x.kgid) +} + +func (x *scmCredentials) afterLoad() {} +func (x *scmCredentials) load(m state.Map) { + m.Load("t", &x.t) + m.Load("kuid", &x.kuid) + m.Load("kgid", &x.kgid) +} + +func init() { + state.Register("control.RightsFiles", (*RightsFiles)(nil), state.Fns{Save: (*RightsFiles).save, Load: (*RightsFiles).load}) + state.Register("control.scmCredentials", (*scmCredentials)(nil), state.Fns{Save: (*scmCredentials).save, Load: (*scmCredentials).load}) +} diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD deleted file mode 100644 index 45bb24a3f..000000000 --- a/pkg/sentry/socket/epsocket/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "epsocket", - srcs = [ - "device.go", - "epsocket.go", - "provider.go", - "save_restore.go", - "stack.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/epsocket", - visibility = [ - "//pkg/sentry:internal", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/log", - "//pkg/metric", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", - "//pkg/sentry/socket", - "//pkg/sentry/unimpl", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/epsocket/epsocket_state_autogen.go b/pkg/sentry/socket/epsocket/epsocket_state_autogen.go new file mode 100755 index 000000000..8a743a4ef --- /dev/null +++ b/pkg/sentry/socket/epsocket/epsocket_state_autogen.go @@ -0,0 +1,54 @@ +// automatically generated by stateify. + +package epsocket + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SocketOperations) beforeSave() {} +func (x *SocketOperations) save(m state.Map) { + x.beforeSave() + m.Save("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Save("Queue", &x.Queue) + m.Save("family", &x.family) + m.Save("Endpoint", &x.Endpoint) + m.Save("skType", &x.skType) + m.Save("protocol", &x.protocol) + m.Save("readView", &x.readView) + m.Save("readCM", &x.readCM) + m.Save("sender", &x.sender) + m.Save("sockOptTimestamp", &x.sockOptTimestamp) + m.Save("timestampValid", &x.timestampValid) + m.Save("timestampNS", &x.timestampNS) +} + +func (x *SocketOperations) afterLoad() {} +func (x *SocketOperations) load(m state.Map) { + m.Load("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Load("Queue", &x.Queue) + m.Load("family", &x.family) + m.Load("Endpoint", &x.Endpoint) + m.Load("skType", &x.skType) + m.Load("protocol", &x.protocol) + m.Load("readView", &x.readView) + m.Load("readCM", &x.readCM) + m.Load("sender", &x.sender) + m.Load("sockOptTimestamp", &x.sockOptTimestamp) + m.Load("timestampValid", &x.timestampValid) + m.Load("timestampNS", &x.timestampNS) +} + +func (x *Stack) beforeSave() {} +func (x *Stack) save(m state.Map) { + x.beforeSave() +} + +func (x *Stack) load(m state.Map) { + m.AfterLoad(x.afterLoad) +} + +func init() { + state.Register("epsocket.SocketOperations", (*SocketOperations)(nil), state.Fns{Save: (*SocketOperations).save, Load: (*SocketOperations).load}) + state.Register("epsocket.Stack", (*Stack)(nil), state.Fns{Save: (*Stack).save, Load: (*Stack).load}) +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD deleted file mode 100644 index 4f670beb4..000000000 --- a/pkg/sentry/socket/hostinet/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "hostinet", - srcs = [ - "device.go", - "hostinet.go", - "save_restore.go", - "socket.go", - "socket_unsafe.go", - "stack.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/hostinet", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", - "//pkg/sentry/socket", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/hostinet/hostinet_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go new file mode 100755 index 000000000..0a5c7cdf3 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package hostinet + diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD deleted file mode 100644 index f6b001b63..000000000 --- a/pkg/sentry/socket/netlink/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "netlink", - srcs = [ - "message.go", - "provider.go", - "socket.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netlink/port", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/netlink/netlink_state_autogen.go b/pkg/sentry/socket/netlink/netlink_state_autogen.go new file mode 100755 index 000000000..794187ec0 --- /dev/null +++ b/pkg/sentry/socket/netlink/netlink_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Socket) beforeSave() {} +func (x *Socket) save(m state.Map) { + x.beforeSave() + m.Save("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Save("ports", &x.ports) + m.Save("protocol", &x.protocol) + m.Save("skType", &x.skType) + m.Save("ep", &x.ep) + m.Save("connection", &x.connection) + m.Save("bound", &x.bound) + m.Save("portID", &x.portID) + m.Save("sendBufferSize", &x.sendBufferSize) +} + +func (x *Socket) afterLoad() {} +func (x *Socket) load(m state.Map) { + m.Load("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Load("ports", &x.ports) + m.Load("protocol", &x.protocol) + m.Load("skType", &x.skType) + m.Load("ep", &x.ep) + m.Load("connection", &x.connection) + m.Load("bound", &x.bound) + m.Load("portID", &x.portID) + m.Load("sendBufferSize", &x.sendBufferSize) +} + +func init() { + state.Register("netlink.Socket", (*Socket)(nil), state.Fns{Save: (*Socket).save, Load: (*Socket).load}) +} diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD deleted file mode 100644 index 9e2e12799..000000000 --- a/pkg/sentry/socket/netlink/port/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "port", - srcs = ["port.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", - visibility = ["//pkg/sentry:internal"], -) - -go_test( - name = "port_test", - srcs = ["port_test.go"], - embed = [":port"], -) diff --git a/pkg/sentry/socket/netlink/port/port_state_autogen.go b/pkg/sentry/socket/netlink/port/port_state_autogen.go new file mode 100755 index 000000000..b4b4a6758 --- /dev/null +++ b/pkg/sentry/socket/netlink/port/port_state_autogen.go @@ -0,0 +1,22 @@ +// automatically generated by stateify. + +package port + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Manager) beforeSave() {} +func (x *Manager) save(m state.Map) { + x.beforeSave() + m.Save("ports", &x.ports) +} + +func (x *Manager) afterLoad() {} +func (x *Manager) load(m state.Map) { + m.Load("ports", &x.ports) +} + +func init() { + state.Register("port.Manager", (*Manager)(nil), state.Fns{Save: (*Manager).save, Load: (*Manager).load}) +} diff --git a/pkg/sentry/socket/netlink/port/port_test.go b/pkg/sentry/socket/netlink/port/port_test.go deleted file mode 100644 index 516f6cd6c..000000000 --- a/pkg/sentry/socket/netlink/port/port_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package port - -import ( - "testing" -) - -func TestAllocateHint(t *testing.T) { - m := New() - - // We can get the hint port. - p, ok := m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } - - // Hint is taken. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p == 1 { - t.Errorf("m.Allocate(0, 1) got 1 want anything else") - } - - // Hint is available for a different protocol. - p, ok = m.Allocate(1, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(1, 1) got %d want 1", p) - } - - m.Release(0, 1) - - // Hint is available again after release. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } -} - -func TestAllocateExhausted(t *testing.T) { - m := New() - - // Fill all ports (0 is already reserved). - for i := int32(1); i < maxPorts; i++ { - p, ok := m.Allocate(0, i) - if !ok { - t.Fatalf("m.Allocate got !ok want ok") - } - if p != i { - t.Fatalf("m.Allocate(0, %d) got %d want %d", i, p, i) - } - } - - // Now no more can be allocated. - p, ok := m.Allocate(0, 1) - if ok { - t.Errorf("m.Allocate got %d, ok want !ok", p) - } -} diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD deleted file mode 100644 index 5dc8533ec..000000000 --- a/pkg/sentry/socket/netlink/route/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "route", - srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/context", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket/netlink", - "//pkg/syserr", - ], -) diff --git a/pkg/sentry/socket/netlink/route/route_state_autogen.go b/pkg/sentry/socket/netlink/route/route_state_autogen.go new file mode 100755 index 000000000..f4225cb3c --- /dev/null +++ b/pkg/sentry/socket/netlink/route/route_state_autogen.go @@ -0,0 +1,20 @@ +// automatically generated by stateify. + +package route + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Protocol) beforeSave() {} +func (x *Protocol) save(m state.Map) { + x.beforeSave() +} + +func (x *Protocol) afterLoad() {} +func (x *Protocol) load(m state.Map) { +} + +func init() { + state.Register("route.Protocol", (*Protocol)(nil), state.Fns{Save: (*Protocol).save, Load: (*Protocol).load}) +} diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD deleted file mode 100644 index 96d374383..000000000 --- a/pkg/sentry/socket/rpcinet/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "rpcinet", - srcs = [ - "device.go", - "rpcinet.go", - "socket.go", - "stack.go", - "stack_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet", - visibility = ["//pkg/sentry:internal"], - deps = [ - ":syscall_rpc_go_proto", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/hostinet", - "//pkg/sentry/socket/rpcinet/conn", - "//pkg/sentry/socket/rpcinet/notifier", - "//pkg/sentry/unimpl", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/unet", - "//pkg/waiter", - ], -) - -proto_library( - name = "syscall_rpc_proto", - srcs = ["syscall_rpc.proto"], - visibility = [ - "//visibility:public", - ], -) - -go_proto_library( - name = "syscall_rpc_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto", - proto = ":syscall_rpc_proto", - visibility = [ - "//visibility:public", - ], -) diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD deleted file mode 100644 index 23eadcb1b..000000000 --- a/pkg/sentry/socket/rpcinet/conn/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "conn", - srcs = ["conn.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/binary", - "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", - "//pkg/syserr", - "//pkg/unet", - "@com_github_golang_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/sentry/socket/rpcinet/conn/conn_state_autogen.go b/pkg/sentry/socket/rpcinet/conn/conn_state_autogen.go new file mode 100755 index 000000000..f6c927a60 --- /dev/null +++ b/pkg/sentry/socket/rpcinet/conn/conn_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package conn + diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD deleted file mode 100644 index a536f2e44..000000000 --- a/pkg/sentry/socket/rpcinet/notifier/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "notifier", - srcs = ["notifier.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", - "//pkg/sentry/socket/rpcinet/conn", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier_state_autogen.go b/pkg/sentry/socket/rpcinet/notifier/notifier_state_autogen.go new file mode 100755 index 000000000..f108d91c1 --- /dev/null +++ b/pkg/sentry/socket/rpcinet/notifier/notifier_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package notifier + diff --git a/pkg/sentry/socket/rpcinet/rpcinet_state_autogen.go b/pkg/sentry/socket/rpcinet/rpcinet_state_autogen.go new file mode 100755 index 000000000..d3076c7e3 --- /dev/null +++ b/pkg/sentry/socket/rpcinet/rpcinet_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package rpcinet + diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto deleted file mode 100644 index 9586f5923..000000000 --- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto +++ /dev/null @@ -1,353 +0,0 @@ -syntax = "proto3"; - -// package syscall_rpc is a set of networking related system calls that can be -// forwarded to a socket gofer. -// -// TODO(b/77963526): Document individual RPCs. -package syscall_rpc; - -message SendmsgRequest { - uint32 fd = 1; - bytes data = 2 [ctype = CORD]; - bytes address = 3; - bool more = 4; - bool end_of_record = 5; -} - -message SendmsgResponse { - oneof result { - uint32 error_number = 1; - uint32 length = 2; - } -} - -message IOCtlRequest { - uint32 fd = 1; - uint32 cmd = 2; - bytes arg = 3; -} - -message IOCtlResponse { - oneof result { - uint32 error_number = 1; - bytes value = 2; - } -} - -message RecvmsgRequest { - uint32 fd = 1; - uint32 length = 2; - bool sender = 3; - bool peek = 4; - bool trunc = 5; - uint32 cmsg_length = 6; -} - -message OpenRequest { - bytes path = 1; - uint32 flags = 2; - uint32 mode = 3; -} - -message OpenResponse { - oneof result { - uint32 error_number = 1; - uint32 fd = 2; - } -} - -message ReadRequest { - uint32 fd = 1; - uint32 length = 2; -} - -message ReadResponse { - oneof result { - uint32 error_number = 1; - bytes data = 2 [ctype = CORD]; - } -} - -message ReadFileRequest { - string path = 1; -} - -message ReadFileResponse { - oneof result { - uint32 error_number = 1; - bytes data = 2 [ctype = CORD]; - } -} - -message WriteRequest { - uint32 fd = 1; - bytes data = 2 [ctype = CORD]; -} - -message WriteResponse { - oneof result { - uint32 error_number = 1; - uint32 length = 2; - } -} - -message WriteFileRequest { - string path = 1; - bytes content = 2; -} - -message WriteFileResponse { - uint32 error_number = 1; - uint32 written = 2; -} - -message AddressResponse { - bytes address = 1; - uint32 length = 2; -} - -message RecvmsgResponse { - message ResultPayload { - bytes data = 1 [ctype = CORD]; - AddressResponse address = 2; - uint32 length = 3; - bytes cmsg_data = 4; - } - oneof result { - uint32 error_number = 1; - ResultPayload payload = 2; - } -} - -message BindRequest { - uint32 fd = 1; - bytes address = 2; -} - -message BindResponse { - uint32 error_number = 1; -} - -message AcceptRequest { - uint32 fd = 1; - bool peer = 2; - int64 flags = 3; -} - -message AcceptResponse { - message ResultPayload { - uint32 fd = 1; - AddressResponse address = 2; - } - oneof result { - uint32 error_number = 1; - ResultPayload payload = 2; - } -} - -message ConnectRequest { - uint32 fd = 1; - bytes address = 2; -} - -message ConnectResponse { - uint32 error_number = 1; -} - -message ListenRequest { - uint32 fd = 1; - int64 backlog = 2; -} - -message ListenResponse { - uint32 error_number = 1; -} - -message ShutdownRequest { - uint32 fd = 1; - int64 how = 2; -} - -message ShutdownResponse { - uint32 error_number = 1; -} - -message CloseRequest { - uint32 fd = 1; -} - -message CloseResponse { - uint32 error_number = 1; -} - -message GetSockOptRequest { - uint32 fd = 1; - int64 level = 2; - int64 name = 3; - uint32 length = 4; -} - -message GetSockOptResponse { - oneof result { - uint32 error_number = 1; - bytes opt = 2; - } -} - -message SetSockOptRequest { - uint32 fd = 1; - int64 level = 2; - int64 name = 3; - bytes opt = 4; -} - -message SetSockOptResponse { - uint32 error_number = 1; -} - -message GetSockNameRequest { - uint32 fd = 1; -} - -message GetSockNameResponse { - oneof result { - uint32 error_number = 1; - AddressResponse address = 2; - } -} - -message GetPeerNameRequest { - uint32 fd = 1; -} - -message GetPeerNameResponse { - oneof result { - uint32 error_number = 1; - AddressResponse address = 2; - } -} - -message SocketRequest { - int64 family = 1; - int64 type = 2; - int64 protocol = 3; -} - -message SocketResponse { - oneof result { - uint32 error_number = 1; - uint32 fd = 2; - } -} - -message EpollWaitRequest { - uint32 fd = 1; - uint32 num_events = 2; - sint64 msec = 3; -} - -message EpollEvent { - uint32 fd = 1; - uint32 events = 2; -} - -message EpollEvents { - repeated EpollEvent events = 1; -} - -message EpollWaitResponse { - oneof result { - uint32 error_number = 1; - EpollEvents events = 2; - } -} - -message EpollCtlRequest { - uint32 epfd = 1; - int64 op = 2; - uint32 fd = 3; - EpollEvent event = 4; -} - -message EpollCtlResponse { - uint32 error_number = 1; -} - -message EpollCreate1Request { - int64 flag = 1; -} - -message EpollCreate1Response { - oneof result { - uint32 error_number = 1; - uint32 fd = 2; - } -} - -message PollRequest { - uint32 fd = 1; - uint32 events = 2; -} - -message PollResponse { - oneof result { - uint32 error_number = 1; - uint32 events = 2; - } -} - -message SyscallRequest { - oneof args { - SocketRequest socket = 1; - SendmsgRequest sendmsg = 2; - RecvmsgRequest recvmsg = 3; - BindRequest bind = 4; - AcceptRequest accept = 5; - ConnectRequest connect = 6; - ListenRequest listen = 7; - ShutdownRequest shutdown = 8; - CloseRequest close = 9; - GetSockOptRequest get_sock_opt = 10; - SetSockOptRequest set_sock_opt = 11; - GetSockNameRequest get_sock_name = 12; - GetPeerNameRequest get_peer_name = 13; - EpollWaitRequest epoll_wait = 14; - EpollCtlRequest epoll_ctl = 15; - EpollCreate1Request epoll_create1 = 16; - PollRequest poll = 17; - ReadRequest read = 18; - WriteRequest write = 19; - OpenRequest open = 20; - IOCtlRequest ioctl = 21; - WriteFileRequest write_file = 22; - ReadFileRequest read_file = 23; - } -} - -message SyscallResponse { - oneof result { - SocketResponse socket = 1; - SendmsgResponse sendmsg = 2; - RecvmsgResponse recvmsg = 3; - BindResponse bind = 4; - AcceptResponse accept = 5; - ConnectResponse connect = 6; - ListenResponse listen = 7; - ShutdownResponse shutdown = 8; - CloseResponse close = 9; - GetSockOptResponse get_sock_opt = 10; - SetSockOptResponse set_sock_opt = 11; - GetSockNameResponse get_sock_name = 12; - GetPeerNameResponse get_peer_name = 13; - EpollWaitResponse epoll_wait = 14; - EpollCtlResponse epoll_ctl = 15; - EpollCreate1Response epoll_create1 = 16; - PollResponse poll = 17; - ReadResponse read = 18; - WriteResponse write = 19; - OpenResponse open = 20; - IOCtlResponse ioctl = 21; - WriteFileResponse write_file = 22; - ReadFileResponse read_file = 23; - } -} diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto/syscall_rpc.pb.go b/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto/syscall_rpc.pb.go new file mode 100755 index 000000000..fb68d5294 --- /dev/null +++ b/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto/syscall_rpc.pb.go @@ -0,0 +1,3938 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/socket/rpcinet/syscall_rpc.proto + +package syscall_rpc + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type SendmsgRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + Address []byte `protobuf:"bytes,3,opt,name=address,proto3" json:"address,omitempty"` + More bool `protobuf:"varint,4,opt,name=more,proto3" json:"more,omitempty"` + EndOfRecord bool `protobuf:"varint,5,opt,name=end_of_record,json=endOfRecord,proto3" json:"end_of_record,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SendmsgRequest) Reset() { *m = SendmsgRequest{} } +func (m *SendmsgRequest) String() string { return proto.CompactTextString(m) } +func (*SendmsgRequest) ProtoMessage() {} +func (*SendmsgRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{0} +} + +func (m *SendmsgRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SendmsgRequest.Unmarshal(m, b) +} +func (m *SendmsgRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SendmsgRequest.Marshal(b, m, deterministic) +} +func (m *SendmsgRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_SendmsgRequest.Merge(m, src) +} +func (m *SendmsgRequest) XXX_Size() int { + return xxx_messageInfo_SendmsgRequest.Size(m) +} +func (m *SendmsgRequest) XXX_DiscardUnknown() { + xxx_messageInfo_SendmsgRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_SendmsgRequest proto.InternalMessageInfo + +func (m *SendmsgRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *SendmsgRequest) GetData() []byte { + if m != nil { + return m.Data + } + return nil +} + +func (m *SendmsgRequest) GetAddress() []byte { + if m != nil { + return m.Address + } + return nil +} + +func (m *SendmsgRequest) GetMore() bool { + if m != nil { + return m.More + } + return false +} + +func (m *SendmsgRequest) GetEndOfRecord() bool { + if m != nil { + return m.EndOfRecord + } + return false +} + +type SendmsgResponse struct { + // Types that are valid to be assigned to Result: + // *SendmsgResponse_ErrorNumber + // *SendmsgResponse_Length + Result isSendmsgResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SendmsgResponse) Reset() { *m = SendmsgResponse{} } +func (m *SendmsgResponse) String() string { return proto.CompactTextString(m) } +func (*SendmsgResponse) ProtoMessage() {} +func (*SendmsgResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{1} +} + +func (m *SendmsgResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SendmsgResponse.Unmarshal(m, b) +} +func (m *SendmsgResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SendmsgResponse.Marshal(b, m, deterministic) +} +func (m *SendmsgResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_SendmsgResponse.Merge(m, src) +} +func (m *SendmsgResponse) XXX_Size() int { + return xxx_messageInfo_SendmsgResponse.Size(m) +} +func (m *SendmsgResponse) XXX_DiscardUnknown() { + xxx_messageInfo_SendmsgResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_SendmsgResponse proto.InternalMessageInfo + +type isSendmsgResponse_Result interface { + isSendmsgResponse_Result() +} + +type SendmsgResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type SendmsgResponse_Length struct { + Length uint32 `protobuf:"varint,2,opt,name=length,proto3,oneof"` +} + +func (*SendmsgResponse_ErrorNumber) isSendmsgResponse_Result() {} + +func (*SendmsgResponse_Length) isSendmsgResponse_Result() {} + +func (m *SendmsgResponse) GetResult() isSendmsgResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *SendmsgResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*SendmsgResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *SendmsgResponse) GetLength() uint32 { + if x, ok := m.GetResult().(*SendmsgResponse_Length); ok { + return x.Length + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*SendmsgResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*SendmsgResponse_ErrorNumber)(nil), + (*SendmsgResponse_Length)(nil), + } +} + +type IOCtlRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Cmd uint32 `protobuf:"varint,2,opt,name=cmd,proto3" json:"cmd,omitempty"` + Arg []byte `protobuf:"bytes,3,opt,name=arg,proto3" json:"arg,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *IOCtlRequest) Reset() { *m = IOCtlRequest{} } +func (m *IOCtlRequest) String() string { return proto.CompactTextString(m) } +func (*IOCtlRequest) ProtoMessage() {} +func (*IOCtlRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{2} +} + +func (m *IOCtlRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_IOCtlRequest.Unmarshal(m, b) +} +func (m *IOCtlRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_IOCtlRequest.Marshal(b, m, deterministic) +} +func (m *IOCtlRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_IOCtlRequest.Merge(m, src) +} +func (m *IOCtlRequest) XXX_Size() int { + return xxx_messageInfo_IOCtlRequest.Size(m) +} +func (m *IOCtlRequest) XXX_DiscardUnknown() { + xxx_messageInfo_IOCtlRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_IOCtlRequest proto.InternalMessageInfo + +func (m *IOCtlRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *IOCtlRequest) GetCmd() uint32 { + if m != nil { + return m.Cmd + } + return 0 +} + +func (m *IOCtlRequest) GetArg() []byte { + if m != nil { + return m.Arg + } + return nil +} + +type IOCtlResponse struct { + // Types that are valid to be assigned to Result: + // *IOCtlResponse_ErrorNumber + // *IOCtlResponse_Value + Result isIOCtlResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *IOCtlResponse) Reset() { *m = IOCtlResponse{} } +func (m *IOCtlResponse) String() string { return proto.CompactTextString(m) } +func (*IOCtlResponse) ProtoMessage() {} +func (*IOCtlResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{3} +} + +func (m *IOCtlResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_IOCtlResponse.Unmarshal(m, b) +} +func (m *IOCtlResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_IOCtlResponse.Marshal(b, m, deterministic) +} +func (m *IOCtlResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_IOCtlResponse.Merge(m, src) +} +func (m *IOCtlResponse) XXX_Size() int { + return xxx_messageInfo_IOCtlResponse.Size(m) +} +func (m *IOCtlResponse) XXX_DiscardUnknown() { + xxx_messageInfo_IOCtlResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_IOCtlResponse proto.InternalMessageInfo + +type isIOCtlResponse_Result interface { + isIOCtlResponse_Result() +} + +type IOCtlResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type IOCtlResponse_Value struct { + Value []byte `protobuf:"bytes,2,opt,name=value,proto3,oneof"` +} + +func (*IOCtlResponse_ErrorNumber) isIOCtlResponse_Result() {} + +func (*IOCtlResponse_Value) isIOCtlResponse_Result() {} + +func (m *IOCtlResponse) GetResult() isIOCtlResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *IOCtlResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*IOCtlResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *IOCtlResponse) GetValue() []byte { + if x, ok := m.GetResult().(*IOCtlResponse_Value); ok { + return x.Value + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*IOCtlResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*IOCtlResponse_ErrorNumber)(nil), + (*IOCtlResponse_Value)(nil), + } +} + +type RecvmsgRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Length uint32 `protobuf:"varint,2,opt,name=length,proto3" json:"length,omitempty"` + Sender bool `protobuf:"varint,3,opt,name=sender,proto3" json:"sender,omitempty"` + Peek bool `protobuf:"varint,4,opt,name=peek,proto3" json:"peek,omitempty"` + Trunc bool `protobuf:"varint,5,opt,name=trunc,proto3" json:"trunc,omitempty"` + CmsgLength uint32 `protobuf:"varint,6,opt,name=cmsg_length,json=cmsgLength,proto3" json:"cmsg_length,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *RecvmsgRequest) Reset() { *m = RecvmsgRequest{} } +func (m *RecvmsgRequest) String() string { return proto.CompactTextString(m) } +func (*RecvmsgRequest) ProtoMessage() {} +func (*RecvmsgRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{4} +} + +func (m *RecvmsgRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_RecvmsgRequest.Unmarshal(m, b) +} +func (m *RecvmsgRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_RecvmsgRequest.Marshal(b, m, deterministic) +} +func (m *RecvmsgRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_RecvmsgRequest.Merge(m, src) +} +func (m *RecvmsgRequest) XXX_Size() int { + return xxx_messageInfo_RecvmsgRequest.Size(m) +} +func (m *RecvmsgRequest) XXX_DiscardUnknown() { + xxx_messageInfo_RecvmsgRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_RecvmsgRequest proto.InternalMessageInfo + +func (m *RecvmsgRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *RecvmsgRequest) GetLength() uint32 { + if m != nil { + return m.Length + } + return 0 +} + +func (m *RecvmsgRequest) GetSender() bool { + if m != nil { + return m.Sender + } + return false +} + +func (m *RecvmsgRequest) GetPeek() bool { + if m != nil { + return m.Peek + } + return false +} + +func (m *RecvmsgRequest) GetTrunc() bool { + if m != nil { + return m.Trunc + } + return false +} + +func (m *RecvmsgRequest) GetCmsgLength() uint32 { + if m != nil { + return m.CmsgLength + } + return 0 +} + +type OpenRequest struct { + Path []byte `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Flags uint32 `protobuf:"varint,2,opt,name=flags,proto3" json:"flags,omitempty"` + Mode uint32 `protobuf:"varint,3,opt,name=mode,proto3" json:"mode,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *OpenRequest) Reset() { *m = OpenRequest{} } +func (m *OpenRequest) String() string { return proto.CompactTextString(m) } +func (*OpenRequest) ProtoMessage() {} +func (*OpenRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{5} +} + +func (m *OpenRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_OpenRequest.Unmarshal(m, b) +} +func (m *OpenRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_OpenRequest.Marshal(b, m, deterministic) +} +func (m *OpenRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_OpenRequest.Merge(m, src) +} +func (m *OpenRequest) XXX_Size() int { + return xxx_messageInfo_OpenRequest.Size(m) +} +func (m *OpenRequest) XXX_DiscardUnknown() { + xxx_messageInfo_OpenRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_OpenRequest proto.InternalMessageInfo + +func (m *OpenRequest) GetPath() []byte { + if m != nil { + return m.Path + } + return nil +} + +func (m *OpenRequest) GetFlags() uint32 { + if m != nil { + return m.Flags + } + return 0 +} + +func (m *OpenRequest) GetMode() uint32 { + if m != nil { + return m.Mode + } + return 0 +} + +type OpenResponse struct { + // Types that are valid to be assigned to Result: + // *OpenResponse_ErrorNumber + // *OpenResponse_Fd + Result isOpenResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *OpenResponse) Reset() { *m = OpenResponse{} } +func (m *OpenResponse) String() string { return proto.CompactTextString(m) } +func (*OpenResponse) ProtoMessage() {} +func (*OpenResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{6} +} + +func (m *OpenResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_OpenResponse.Unmarshal(m, b) +} +func (m *OpenResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_OpenResponse.Marshal(b, m, deterministic) +} +func (m *OpenResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_OpenResponse.Merge(m, src) +} +func (m *OpenResponse) XXX_Size() int { + return xxx_messageInfo_OpenResponse.Size(m) +} +func (m *OpenResponse) XXX_DiscardUnknown() { + xxx_messageInfo_OpenResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_OpenResponse proto.InternalMessageInfo + +type isOpenResponse_Result interface { + isOpenResponse_Result() +} + +type OpenResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type OpenResponse_Fd struct { + Fd uint32 `protobuf:"varint,2,opt,name=fd,proto3,oneof"` +} + +func (*OpenResponse_ErrorNumber) isOpenResponse_Result() {} + +func (*OpenResponse_Fd) isOpenResponse_Result() {} + +func (m *OpenResponse) GetResult() isOpenResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *OpenResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*OpenResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *OpenResponse) GetFd() uint32 { + if x, ok := m.GetResult().(*OpenResponse_Fd); ok { + return x.Fd + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*OpenResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*OpenResponse_ErrorNumber)(nil), + (*OpenResponse_Fd)(nil), + } +} + +type ReadRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Length uint32 `protobuf:"varint,2,opt,name=length,proto3" json:"length,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ReadRequest) Reset() { *m = ReadRequest{} } +func (m *ReadRequest) String() string { return proto.CompactTextString(m) } +func (*ReadRequest) ProtoMessage() {} +func (*ReadRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{7} +} + +func (m *ReadRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ReadRequest.Unmarshal(m, b) +} +func (m *ReadRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ReadRequest.Marshal(b, m, deterministic) +} +func (m *ReadRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReadRequest.Merge(m, src) +} +func (m *ReadRequest) XXX_Size() int { + return xxx_messageInfo_ReadRequest.Size(m) +} +func (m *ReadRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ReadRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ReadRequest proto.InternalMessageInfo + +func (m *ReadRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *ReadRequest) GetLength() uint32 { + if m != nil { + return m.Length + } + return 0 +} + +type ReadResponse struct { + // Types that are valid to be assigned to Result: + // *ReadResponse_ErrorNumber + // *ReadResponse_Data + Result isReadResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ReadResponse) Reset() { *m = ReadResponse{} } +func (m *ReadResponse) String() string { return proto.CompactTextString(m) } +func (*ReadResponse) ProtoMessage() {} +func (*ReadResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{8} +} + +func (m *ReadResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ReadResponse.Unmarshal(m, b) +} +func (m *ReadResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ReadResponse.Marshal(b, m, deterministic) +} +func (m *ReadResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReadResponse.Merge(m, src) +} +func (m *ReadResponse) XXX_Size() int { + return xxx_messageInfo_ReadResponse.Size(m) +} +func (m *ReadResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ReadResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ReadResponse proto.InternalMessageInfo + +type isReadResponse_Result interface { + isReadResponse_Result() +} + +type ReadResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type ReadResponse_Data struct { + Data []byte `protobuf:"bytes,2,opt,name=data,proto3,oneof"` +} + +func (*ReadResponse_ErrorNumber) isReadResponse_Result() {} + +func (*ReadResponse_Data) isReadResponse_Result() {} + +func (m *ReadResponse) GetResult() isReadResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *ReadResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*ReadResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *ReadResponse) GetData() []byte { + if x, ok := m.GetResult().(*ReadResponse_Data); ok { + return x.Data + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*ReadResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*ReadResponse_ErrorNumber)(nil), + (*ReadResponse_Data)(nil), + } +} + +type ReadFileRequest struct { + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ReadFileRequest) Reset() { *m = ReadFileRequest{} } +func (m *ReadFileRequest) String() string { return proto.CompactTextString(m) } +func (*ReadFileRequest) ProtoMessage() {} +func (*ReadFileRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{9} +} + +func (m *ReadFileRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ReadFileRequest.Unmarshal(m, b) +} +func (m *ReadFileRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ReadFileRequest.Marshal(b, m, deterministic) +} +func (m *ReadFileRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReadFileRequest.Merge(m, src) +} +func (m *ReadFileRequest) XXX_Size() int { + return xxx_messageInfo_ReadFileRequest.Size(m) +} +func (m *ReadFileRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ReadFileRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ReadFileRequest proto.InternalMessageInfo + +func (m *ReadFileRequest) GetPath() string { + if m != nil { + return m.Path + } + return "" +} + +type ReadFileResponse struct { + // Types that are valid to be assigned to Result: + // *ReadFileResponse_ErrorNumber + // *ReadFileResponse_Data + Result isReadFileResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ReadFileResponse) Reset() { *m = ReadFileResponse{} } +func (m *ReadFileResponse) String() string { return proto.CompactTextString(m) } +func (*ReadFileResponse) ProtoMessage() {} +func (*ReadFileResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{10} +} + +func (m *ReadFileResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ReadFileResponse.Unmarshal(m, b) +} +func (m *ReadFileResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ReadFileResponse.Marshal(b, m, deterministic) +} +func (m *ReadFileResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReadFileResponse.Merge(m, src) +} +func (m *ReadFileResponse) XXX_Size() int { + return xxx_messageInfo_ReadFileResponse.Size(m) +} +func (m *ReadFileResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ReadFileResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ReadFileResponse proto.InternalMessageInfo + +type isReadFileResponse_Result interface { + isReadFileResponse_Result() +} + +type ReadFileResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type ReadFileResponse_Data struct { + Data []byte `protobuf:"bytes,2,opt,name=data,proto3,oneof"` +} + +func (*ReadFileResponse_ErrorNumber) isReadFileResponse_Result() {} + +func (*ReadFileResponse_Data) isReadFileResponse_Result() {} + +func (m *ReadFileResponse) GetResult() isReadFileResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *ReadFileResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*ReadFileResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *ReadFileResponse) GetData() []byte { + if x, ok := m.GetResult().(*ReadFileResponse_Data); ok { + return x.Data + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*ReadFileResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*ReadFileResponse_ErrorNumber)(nil), + (*ReadFileResponse_Data)(nil), + } +} + +type WriteRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *WriteRequest) Reset() { *m = WriteRequest{} } +func (m *WriteRequest) String() string { return proto.CompactTextString(m) } +func (*WriteRequest) ProtoMessage() {} +func (*WriteRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{11} +} + +func (m *WriteRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_WriteRequest.Unmarshal(m, b) +} +func (m *WriteRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_WriteRequest.Marshal(b, m, deterministic) +} +func (m *WriteRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_WriteRequest.Merge(m, src) +} +func (m *WriteRequest) XXX_Size() int { + return xxx_messageInfo_WriteRequest.Size(m) +} +func (m *WriteRequest) XXX_DiscardUnknown() { + xxx_messageInfo_WriteRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_WriteRequest proto.InternalMessageInfo + +func (m *WriteRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *WriteRequest) GetData() []byte { + if m != nil { + return m.Data + } + return nil +} + +type WriteResponse struct { + // Types that are valid to be assigned to Result: + // *WriteResponse_ErrorNumber + // *WriteResponse_Length + Result isWriteResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *WriteResponse) Reset() { *m = WriteResponse{} } +func (m *WriteResponse) String() string { return proto.CompactTextString(m) } +func (*WriteResponse) ProtoMessage() {} +func (*WriteResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{12} +} + +func (m *WriteResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_WriteResponse.Unmarshal(m, b) +} +func (m *WriteResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_WriteResponse.Marshal(b, m, deterministic) +} +func (m *WriteResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_WriteResponse.Merge(m, src) +} +func (m *WriteResponse) XXX_Size() int { + return xxx_messageInfo_WriteResponse.Size(m) +} +func (m *WriteResponse) XXX_DiscardUnknown() { + xxx_messageInfo_WriteResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_WriteResponse proto.InternalMessageInfo + +type isWriteResponse_Result interface { + isWriteResponse_Result() +} + +type WriteResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type WriteResponse_Length struct { + Length uint32 `protobuf:"varint,2,opt,name=length,proto3,oneof"` +} + +func (*WriteResponse_ErrorNumber) isWriteResponse_Result() {} + +func (*WriteResponse_Length) isWriteResponse_Result() {} + +func (m *WriteResponse) GetResult() isWriteResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *WriteResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*WriteResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *WriteResponse) GetLength() uint32 { + if x, ok := m.GetResult().(*WriteResponse_Length); ok { + return x.Length + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*WriteResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*WriteResponse_ErrorNumber)(nil), + (*WriteResponse_Length)(nil), + } +} + +type WriteFileRequest struct { + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Content []byte `protobuf:"bytes,2,opt,name=content,proto3" json:"content,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *WriteFileRequest) Reset() { *m = WriteFileRequest{} } +func (m *WriteFileRequest) String() string { return proto.CompactTextString(m) } +func (*WriteFileRequest) ProtoMessage() {} +func (*WriteFileRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{13} +} + +func (m *WriteFileRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_WriteFileRequest.Unmarshal(m, b) +} +func (m *WriteFileRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_WriteFileRequest.Marshal(b, m, deterministic) +} +func (m *WriteFileRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_WriteFileRequest.Merge(m, src) +} +func (m *WriteFileRequest) XXX_Size() int { + return xxx_messageInfo_WriteFileRequest.Size(m) +} +func (m *WriteFileRequest) XXX_DiscardUnknown() { + xxx_messageInfo_WriteFileRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_WriteFileRequest proto.InternalMessageInfo + +func (m *WriteFileRequest) GetPath() string { + if m != nil { + return m.Path + } + return "" +} + +func (m *WriteFileRequest) GetContent() []byte { + if m != nil { + return m.Content + } + return nil +} + +type WriteFileResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + Written uint32 `protobuf:"varint,2,opt,name=written,proto3" json:"written,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *WriteFileResponse) Reset() { *m = WriteFileResponse{} } +func (m *WriteFileResponse) String() string { return proto.CompactTextString(m) } +func (*WriteFileResponse) ProtoMessage() {} +func (*WriteFileResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{14} +} + +func (m *WriteFileResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_WriteFileResponse.Unmarshal(m, b) +} +func (m *WriteFileResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_WriteFileResponse.Marshal(b, m, deterministic) +} +func (m *WriteFileResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_WriteFileResponse.Merge(m, src) +} +func (m *WriteFileResponse) XXX_Size() int { + return xxx_messageInfo_WriteFileResponse.Size(m) +} +func (m *WriteFileResponse) XXX_DiscardUnknown() { + xxx_messageInfo_WriteFileResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_WriteFileResponse proto.InternalMessageInfo + +func (m *WriteFileResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +func (m *WriteFileResponse) GetWritten() uint32 { + if m != nil { + return m.Written + } + return 0 +} + +type AddressResponse struct { + Address []byte `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"` + Length uint32 `protobuf:"varint,2,opt,name=length,proto3" json:"length,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AddressResponse) Reset() { *m = AddressResponse{} } +func (m *AddressResponse) String() string { return proto.CompactTextString(m) } +func (*AddressResponse) ProtoMessage() {} +func (*AddressResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{15} +} + +func (m *AddressResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AddressResponse.Unmarshal(m, b) +} +func (m *AddressResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AddressResponse.Marshal(b, m, deterministic) +} +func (m *AddressResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_AddressResponse.Merge(m, src) +} +func (m *AddressResponse) XXX_Size() int { + return xxx_messageInfo_AddressResponse.Size(m) +} +func (m *AddressResponse) XXX_DiscardUnknown() { + xxx_messageInfo_AddressResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_AddressResponse proto.InternalMessageInfo + +func (m *AddressResponse) GetAddress() []byte { + if m != nil { + return m.Address + } + return nil +} + +func (m *AddressResponse) GetLength() uint32 { + if m != nil { + return m.Length + } + return 0 +} + +type RecvmsgResponse struct { + // Types that are valid to be assigned to Result: + // *RecvmsgResponse_ErrorNumber + // *RecvmsgResponse_Payload + Result isRecvmsgResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *RecvmsgResponse) Reset() { *m = RecvmsgResponse{} } +func (m *RecvmsgResponse) String() string { return proto.CompactTextString(m) } +func (*RecvmsgResponse) ProtoMessage() {} +func (*RecvmsgResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{16} +} + +func (m *RecvmsgResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_RecvmsgResponse.Unmarshal(m, b) +} +func (m *RecvmsgResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_RecvmsgResponse.Marshal(b, m, deterministic) +} +func (m *RecvmsgResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_RecvmsgResponse.Merge(m, src) +} +func (m *RecvmsgResponse) XXX_Size() int { + return xxx_messageInfo_RecvmsgResponse.Size(m) +} +func (m *RecvmsgResponse) XXX_DiscardUnknown() { + xxx_messageInfo_RecvmsgResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_RecvmsgResponse proto.InternalMessageInfo + +type isRecvmsgResponse_Result interface { + isRecvmsgResponse_Result() +} + +type RecvmsgResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type RecvmsgResponse_Payload struct { + Payload *RecvmsgResponse_ResultPayload `protobuf:"bytes,2,opt,name=payload,proto3,oneof"` +} + +func (*RecvmsgResponse_ErrorNumber) isRecvmsgResponse_Result() {} + +func (*RecvmsgResponse_Payload) isRecvmsgResponse_Result() {} + +func (m *RecvmsgResponse) GetResult() isRecvmsgResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *RecvmsgResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*RecvmsgResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *RecvmsgResponse) GetPayload() *RecvmsgResponse_ResultPayload { + if x, ok := m.GetResult().(*RecvmsgResponse_Payload); ok { + return x.Payload + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*RecvmsgResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*RecvmsgResponse_ErrorNumber)(nil), + (*RecvmsgResponse_Payload)(nil), + } +} + +type RecvmsgResponse_ResultPayload struct { + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"` + Length uint32 `protobuf:"varint,3,opt,name=length,proto3" json:"length,omitempty"` + CmsgData []byte `protobuf:"bytes,4,opt,name=cmsg_data,json=cmsgData,proto3" json:"cmsg_data,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *RecvmsgResponse_ResultPayload) Reset() { *m = RecvmsgResponse_ResultPayload{} } +func (m *RecvmsgResponse_ResultPayload) String() string { return proto.CompactTextString(m) } +func (*RecvmsgResponse_ResultPayload) ProtoMessage() {} +func (*RecvmsgResponse_ResultPayload) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{16, 0} +} + +func (m *RecvmsgResponse_ResultPayload) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_RecvmsgResponse_ResultPayload.Unmarshal(m, b) +} +func (m *RecvmsgResponse_ResultPayload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_RecvmsgResponse_ResultPayload.Marshal(b, m, deterministic) +} +func (m *RecvmsgResponse_ResultPayload) XXX_Merge(src proto.Message) { + xxx_messageInfo_RecvmsgResponse_ResultPayload.Merge(m, src) +} +func (m *RecvmsgResponse_ResultPayload) XXX_Size() int { + return xxx_messageInfo_RecvmsgResponse_ResultPayload.Size(m) +} +func (m *RecvmsgResponse_ResultPayload) XXX_DiscardUnknown() { + xxx_messageInfo_RecvmsgResponse_ResultPayload.DiscardUnknown(m) +} + +var xxx_messageInfo_RecvmsgResponse_ResultPayload proto.InternalMessageInfo + +func (m *RecvmsgResponse_ResultPayload) GetData() []byte { + if m != nil { + return m.Data + } + return nil +} + +func (m *RecvmsgResponse_ResultPayload) GetAddress() *AddressResponse { + if m != nil { + return m.Address + } + return nil +} + +func (m *RecvmsgResponse_ResultPayload) GetLength() uint32 { + if m != nil { + return m.Length + } + return 0 +} + +func (m *RecvmsgResponse_ResultPayload) GetCmsgData() []byte { + if m != nil { + return m.CmsgData + } + return nil +} + +type BindRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Address []byte `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *BindRequest) Reset() { *m = BindRequest{} } +func (m *BindRequest) String() string { return proto.CompactTextString(m) } +func (*BindRequest) ProtoMessage() {} +func (*BindRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{17} +} + +func (m *BindRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_BindRequest.Unmarshal(m, b) +} +func (m *BindRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_BindRequest.Marshal(b, m, deterministic) +} +func (m *BindRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_BindRequest.Merge(m, src) +} +func (m *BindRequest) XXX_Size() int { + return xxx_messageInfo_BindRequest.Size(m) +} +func (m *BindRequest) XXX_DiscardUnknown() { + xxx_messageInfo_BindRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_BindRequest proto.InternalMessageInfo + +func (m *BindRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *BindRequest) GetAddress() []byte { + if m != nil { + return m.Address + } + return nil +} + +type BindResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *BindResponse) Reset() { *m = BindResponse{} } +func (m *BindResponse) String() string { return proto.CompactTextString(m) } +func (*BindResponse) ProtoMessage() {} +func (*BindResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{18} +} + +func (m *BindResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_BindResponse.Unmarshal(m, b) +} +func (m *BindResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_BindResponse.Marshal(b, m, deterministic) +} +func (m *BindResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_BindResponse.Merge(m, src) +} +func (m *BindResponse) XXX_Size() int { + return xxx_messageInfo_BindResponse.Size(m) +} +func (m *BindResponse) XXX_DiscardUnknown() { + xxx_messageInfo_BindResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_BindResponse proto.InternalMessageInfo + +func (m *BindResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +type AcceptRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Peer bool `protobuf:"varint,2,opt,name=peer,proto3" json:"peer,omitempty"` + Flags int64 `protobuf:"varint,3,opt,name=flags,proto3" json:"flags,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AcceptRequest) Reset() { *m = AcceptRequest{} } +func (m *AcceptRequest) String() string { return proto.CompactTextString(m) } +func (*AcceptRequest) ProtoMessage() {} +func (*AcceptRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{19} +} + +func (m *AcceptRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AcceptRequest.Unmarshal(m, b) +} +func (m *AcceptRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AcceptRequest.Marshal(b, m, deterministic) +} +func (m *AcceptRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_AcceptRequest.Merge(m, src) +} +func (m *AcceptRequest) XXX_Size() int { + return xxx_messageInfo_AcceptRequest.Size(m) +} +func (m *AcceptRequest) XXX_DiscardUnknown() { + xxx_messageInfo_AcceptRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_AcceptRequest proto.InternalMessageInfo + +func (m *AcceptRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *AcceptRequest) GetPeer() bool { + if m != nil { + return m.Peer + } + return false +} + +func (m *AcceptRequest) GetFlags() int64 { + if m != nil { + return m.Flags + } + return 0 +} + +type AcceptResponse struct { + // Types that are valid to be assigned to Result: + // *AcceptResponse_ErrorNumber + // *AcceptResponse_Payload + Result isAcceptResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AcceptResponse) Reset() { *m = AcceptResponse{} } +func (m *AcceptResponse) String() string { return proto.CompactTextString(m) } +func (*AcceptResponse) ProtoMessage() {} +func (*AcceptResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{20} +} + +func (m *AcceptResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AcceptResponse.Unmarshal(m, b) +} +func (m *AcceptResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AcceptResponse.Marshal(b, m, deterministic) +} +func (m *AcceptResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_AcceptResponse.Merge(m, src) +} +func (m *AcceptResponse) XXX_Size() int { + return xxx_messageInfo_AcceptResponse.Size(m) +} +func (m *AcceptResponse) XXX_DiscardUnknown() { + xxx_messageInfo_AcceptResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_AcceptResponse proto.InternalMessageInfo + +type isAcceptResponse_Result interface { + isAcceptResponse_Result() +} + +type AcceptResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type AcceptResponse_Payload struct { + Payload *AcceptResponse_ResultPayload `protobuf:"bytes,2,opt,name=payload,proto3,oneof"` +} + +func (*AcceptResponse_ErrorNumber) isAcceptResponse_Result() {} + +func (*AcceptResponse_Payload) isAcceptResponse_Result() {} + +func (m *AcceptResponse) GetResult() isAcceptResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *AcceptResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*AcceptResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *AcceptResponse) GetPayload() *AcceptResponse_ResultPayload { + if x, ok := m.GetResult().(*AcceptResponse_Payload); ok { + return x.Payload + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*AcceptResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*AcceptResponse_ErrorNumber)(nil), + (*AcceptResponse_Payload)(nil), + } +} + +type AcceptResponse_ResultPayload struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *AcceptResponse_ResultPayload) Reset() { *m = AcceptResponse_ResultPayload{} } +func (m *AcceptResponse_ResultPayload) String() string { return proto.CompactTextString(m) } +func (*AcceptResponse_ResultPayload) ProtoMessage() {} +func (*AcceptResponse_ResultPayload) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{20, 0} +} + +func (m *AcceptResponse_ResultPayload) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_AcceptResponse_ResultPayload.Unmarshal(m, b) +} +func (m *AcceptResponse_ResultPayload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_AcceptResponse_ResultPayload.Marshal(b, m, deterministic) +} +func (m *AcceptResponse_ResultPayload) XXX_Merge(src proto.Message) { + xxx_messageInfo_AcceptResponse_ResultPayload.Merge(m, src) +} +func (m *AcceptResponse_ResultPayload) XXX_Size() int { + return xxx_messageInfo_AcceptResponse_ResultPayload.Size(m) +} +func (m *AcceptResponse_ResultPayload) XXX_DiscardUnknown() { + xxx_messageInfo_AcceptResponse_ResultPayload.DiscardUnknown(m) +} + +var xxx_messageInfo_AcceptResponse_ResultPayload proto.InternalMessageInfo + +func (m *AcceptResponse_ResultPayload) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *AcceptResponse_ResultPayload) GetAddress() *AddressResponse { + if m != nil { + return m.Address + } + return nil +} + +type ConnectRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Address []byte `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ConnectRequest) Reset() { *m = ConnectRequest{} } +func (m *ConnectRequest) String() string { return proto.CompactTextString(m) } +func (*ConnectRequest) ProtoMessage() {} +func (*ConnectRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{21} +} + +func (m *ConnectRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ConnectRequest.Unmarshal(m, b) +} +func (m *ConnectRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ConnectRequest.Marshal(b, m, deterministic) +} +func (m *ConnectRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ConnectRequest.Merge(m, src) +} +func (m *ConnectRequest) XXX_Size() int { + return xxx_messageInfo_ConnectRequest.Size(m) +} +func (m *ConnectRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ConnectRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ConnectRequest proto.InternalMessageInfo + +func (m *ConnectRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *ConnectRequest) GetAddress() []byte { + if m != nil { + return m.Address + } + return nil +} + +type ConnectResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ConnectResponse) Reset() { *m = ConnectResponse{} } +func (m *ConnectResponse) String() string { return proto.CompactTextString(m) } +func (*ConnectResponse) ProtoMessage() {} +func (*ConnectResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{22} +} + +func (m *ConnectResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ConnectResponse.Unmarshal(m, b) +} +func (m *ConnectResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ConnectResponse.Marshal(b, m, deterministic) +} +func (m *ConnectResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ConnectResponse.Merge(m, src) +} +func (m *ConnectResponse) XXX_Size() int { + return xxx_messageInfo_ConnectResponse.Size(m) +} +func (m *ConnectResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ConnectResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ConnectResponse proto.InternalMessageInfo + +func (m *ConnectResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +type ListenRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Backlog int64 `protobuf:"varint,2,opt,name=backlog,proto3" json:"backlog,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ListenRequest) Reset() { *m = ListenRequest{} } +func (m *ListenRequest) String() string { return proto.CompactTextString(m) } +func (*ListenRequest) ProtoMessage() {} +func (*ListenRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{23} +} + +func (m *ListenRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ListenRequest.Unmarshal(m, b) +} +func (m *ListenRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ListenRequest.Marshal(b, m, deterministic) +} +func (m *ListenRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListenRequest.Merge(m, src) +} +func (m *ListenRequest) XXX_Size() int { + return xxx_messageInfo_ListenRequest.Size(m) +} +func (m *ListenRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ListenRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ListenRequest proto.InternalMessageInfo + +func (m *ListenRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *ListenRequest) GetBacklog() int64 { + if m != nil { + return m.Backlog + } + return 0 +} + +type ListenResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ListenResponse) Reset() { *m = ListenResponse{} } +func (m *ListenResponse) String() string { return proto.CompactTextString(m) } +func (*ListenResponse) ProtoMessage() {} +func (*ListenResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{24} +} + +func (m *ListenResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ListenResponse.Unmarshal(m, b) +} +func (m *ListenResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ListenResponse.Marshal(b, m, deterministic) +} +func (m *ListenResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListenResponse.Merge(m, src) +} +func (m *ListenResponse) XXX_Size() int { + return xxx_messageInfo_ListenResponse.Size(m) +} +func (m *ListenResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ListenResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ListenResponse proto.InternalMessageInfo + +func (m *ListenResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +type ShutdownRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + How int64 `protobuf:"varint,2,opt,name=how,proto3" json:"how,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ShutdownRequest) Reset() { *m = ShutdownRequest{} } +func (m *ShutdownRequest) String() string { return proto.CompactTextString(m) } +func (*ShutdownRequest) ProtoMessage() {} +func (*ShutdownRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{25} +} + +func (m *ShutdownRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ShutdownRequest.Unmarshal(m, b) +} +func (m *ShutdownRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ShutdownRequest.Marshal(b, m, deterministic) +} +func (m *ShutdownRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ShutdownRequest.Merge(m, src) +} +func (m *ShutdownRequest) XXX_Size() int { + return xxx_messageInfo_ShutdownRequest.Size(m) +} +func (m *ShutdownRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ShutdownRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ShutdownRequest proto.InternalMessageInfo + +func (m *ShutdownRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *ShutdownRequest) GetHow() int64 { + if m != nil { + return m.How + } + return 0 +} + +type ShutdownResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ShutdownResponse) Reset() { *m = ShutdownResponse{} } +func (m *ShutdownResponse) String() string { return proto.CompactTextString(m) } +func (*ShutdownResponse) ProtoMessage() {} +func (*ShutdownResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{26} +} + +func (m *ShutdownResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ShutdownResponse.Unmarshal(m, b) +} +func (m *ShutdownResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ShutdownResponse.Marshal(b, m, deterministic) +} +func (m *ShutdownResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ShutdownResponse.Merge(m, src) +} +func (m *ShutdownResponse) XXX_Size() int { + return xxx_messageInfo_ShutdownResponse.Size(m) +} +func (m *ShutdownResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ShutdownResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ShutdownResponse proto.InternalMessageInfo + +func (m *ShutdownResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +type CloseRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CloseRequest) Reset() { *m = CloseRequest{} } +func (m *CloseRequest) String() string { return proto.CompactTextString(m) } +func (*CloseRequest) ProtoMessage() {} +func (*CloseRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{27} +} + +func (m *CloseRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CloseRequest.Unmarshal(m, b) +} +func (m *CloseRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CloseRequest.Marshal(b, m, deterministic) +} +func (m *CloseRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_CloseRequest.Merge(m, src) +} +func (m *CloseRequest) XXX_Size() int { + return xxx_messageInfo_CloseRequest.Size(m) +} +func (m *CloseRequest) XXX_DiscardUnknown() { + xxx_messageInfo_CloseRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_CloseRequest proto.InternalMessageInfo + +func (m *CloseRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +type CloseResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *CloseResponse) Reset() { *m = CloseResponse{} } +func (m *CloseResponse) String() string { return proto.CompactTextString(m) } +func (*CloseResponse) ProtoMessage() {} +func (*CloseResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{28} +} + +func (m *CloseResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_CloseResponse.Unmarshal(m, b) +} +func (m *CloseResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_CloseResponse.Marshal(b, m, deterministic) +} +func (m *CloseResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_CloseResponse.Merge(m, src) +} +func (m *CloseResponse) XXX_Size() int { + return xxx_messageInfo_CloseResponse.Size(m) +} +func (m *CloseResponse) XXX_DiscardUnknown() { + xxx_messageInfo_CloseResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_CloseResponse proto.InternalMessageInfo + +func (m *CloseResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +type GetSockOptRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Level int64 `protobuf:"varint,2,opt,name=level,proto3" json:"level,omitempty"` + Name int64 `protobuf:"varint,3,opt,name=name,proto3" json:"name,omitempty"` + Length uint32 `protobuf:"varint,4,opt,name=length,proto3" json:"length,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetSockOptRequest) Reset() { *m = GetSockOptRequest{} } +func (m *GetSockOptRequest) String() string { return proto.CompactTextString(m) } +func (*GetSockOptRequest) ProtoMessage() {} +func (*GetSockOptRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{29} +} + +func (m *GetSockOptRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetSockOptRequest.Unmarshal(m, b) +} +func (m *GetSockOptRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetSockOptRequest.Marshal(b, m, deterministic) +} +func (m *GetSockOptRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetSockOptRequest.Merge(m, src) +} +func (m *GetSockOptRequest) XXX_Size() int { + return xxx_messageInfo_GetSockOptRequest.Size(m) +} +func (m *GetSockOptRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetSockOptRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetSockOptRequest proto.InternalMessageInfo + +func (m *GetSockOptRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *GetSockOptRequest) GetLevel() int64 { + if m != nil { + return m.Level + } + return 0 +} + +func (m *GetSockOptRequest) GetName() int64 { + if m != nil { + return m.Name + } + return 0 +} + +func (m *GetSockOptRequest) GetLength() uint32 { + if m != nil { + return m.Length + } + return 0 +} + +type GetSockOptResponse struct { + // Types that are valid to be assigned to Result: + // *GetSockOptResponse_ErrorNumber + // *GetSockOptResponse_Opt + Result isGetSockOptResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetSockOptResponse) Reset() { *m = GetSockOptResponse{} } +func (m *GetSockOptResponse) String() string { return proto.CompactTextString(m) } +func (*GetSockOptResponse) ProtoMessage() {} +func (*GetSockOptResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{30} +} + +func (m *GetSockOptResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetSockOptResponse.Unmarshal(m, b) +} +func (m *GetSockOptResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetSockOptResponse.Marshal(b, m, deterministic) +} +func (m *GetSockOptResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetSockOptResponse.Merge(m, src) +} +func (m *GetSockOptResponse) XXX_Size() int { + return xxx_messageInfo_GetSockOptResponse.Size(m) +} +func (m *GetSockOptResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetSockOptResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetSockOptResponse proto.InternalMessageInfo + +type isGetSockOptResponse_Result interface { + isGetSockOptResponse_Result() +} + +type GetSockOptResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type GetSockOptResponse_Opt struct { + Opt []byte `protobuf:"bytes,2,opt,name=opt,proto3,oneof"` +} + +func (*GetSockOptResponse_ErrorNumber) isGetSockOptResponse_Result() {} + +func (*GetSockOptResponse_Opt) isGetSockOptResponse_Result() {} + +func (m *GetSockOptResponse) GetResult() isGetSockOptResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *GetSockOptResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*GetSockOptResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *GetSockOptResponse) GetOpt() []byte { + if x, ok := m.GetResult().(*GetSockOptResponse_Opt); ok { + return x.Opt + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*GetSockOptResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*GetSockOptResponse_ErrorNumber)(nil), + (*GetSockOptResponse_Opt)(nil), + } +} + +type SetSockOptRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Level int64 `protobuf:"varint,2,opt,name=level,proto3" json:"level,omitempty"` + Name int64 `protobuf:"varint,3,opt,name=name,proto3" json:"name,omitempty"` + Opt []byte `protobuf:"bytes,4,opt,name=opt,proto3" json:"opt,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SetSockOptRequest) Reset() { *m = SetSockOptRequest{} } +func (m *SetSockOptRequest) String() string { return proto.CompactTextString(m) } +func (*SetSockOptRequest) ProtoMessage() {} +func (*SetSockOptRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{31} +} + +func (m *SetSockOptRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SetSockOptRequest.Unmarshal(m, b) +} +func (m *SetSockOptRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SetSockOptRequest.Marshal(b, m, deterministic) +} +func (m *SetSockOptRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_SetSockOptRequest.Merge(m, src) +} +func (m *SetSockOptRequest) XXX_Size() int { + return xxx_messageInfo_SetSockOptRequest.Size(m) +} +func (m *SetSockOptRequest) XXX_DiscardUnknown() { + xxx_messageInfo_SetSockOptRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_SetSockOptRequest proto.InternalMessageInfo + +func (m *SetSockOptRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *SetSockOptRequest) GetLevel() int64 { + if m != nil { + return m.Level + } + return 0 +} + +func (m *SetSockOptRequest) GetName() int64 { + if m != nil { + return m.Name + } + return 0 +} + +func (m *SetSockOptRequest) GetOpt() []byte { + if m != nil { + return m.Opt + } + return nil +} + +type SetSockOptResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SetSockOptResponse) Reset() { *m = SetSockOptResponse{} } +func (m *SetSockOptResponse) String() string { return proto.CompactTextString(m) } +func (*SetSockOptResponse) ProtoMessage() {} +func (*SetSockOptResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{32} +} + +func (m *SetSockOptResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SetSockOptResponse.Unmarshal(m, b) +} +func (m *SetSockOptResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SetSockOptResponse.Marshal(b, m, deterministic) +} +func (m *SetSockOptResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_SetSockOptResponse.Merge(m, src) +} +func (m *SetSockOptResponse) XXX_Size() int { + return xxx_messageInfo_SetSockOptResponse.Size(m) +} +func (m *SetSockOptResponse) XXX_DiscardUnknown() { + xxx_messageInfo_SetSockOptResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_SetSockOptResponse proto.InternalMessageInfo + +func (m *SetSockOptResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +type GetSockNameRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetSockNameRequest) Reset() { *m = GetSockNameRequest{} } +func (m *GetSockNameRequest) String() string { return proto.CompactTextString(m) } +func (*GetSockNameRequest) ProtoMessage() {} +func (*GetSockNameRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{33} +} + +func (m *GetSockNameRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetSockNameRequest.Unmarshal(m, b) +} +func (m *GetSockNameRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetSockNameRequest.Marshal(b, m, deterministic) +} +func (m *GetSockNameRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetSockNameRequest.Merge(m, src) +} +func (m *GetSockNameRequest) XXX_Size() int { + return xxx_messageInfo_GetSockNameRequest.Size(m) +} +func (m *GetSockNameRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetSockNameRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetSockNameRequest proto.InternalMessageInfo + +func (m *GetSockNameRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +type GetSockNameResponse struct { + // Types that are valid to be assigned to Result: + // *GetSockNameResponse_ErrorNumber + // *GetSockNameResponse_Address + Result isGetSockNameResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetSockNameResponse) Reset() { *m = GetSockNameResponse{} } +func (m *GetSockNameResponse) String() string { return proto.CompactTextString(m) } +func (*GetSockNameResponse) ProtoMessage() {} +func (*GetSockNameResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{34} +} + +func (m *GetSockNameResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetSockNameResponse.Unmarshal(m, b) +} +func (m *GetSockNameResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetSockNameResponse.Marshal(b, m, deterministic) +} +func (m *GetSockNameResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetSockNameResponse.Merge(m, src) +} +func (m *GetSockNameResponse) XXX_Size() int { + return xxx_messageInfo_GetSockNameResponse.Size(m) +} +func (m *GetSockNameResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetSockNameResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetSockNameResponse proto.InternalMessageInfo + +type isGetSockNameResponse_Result interface { + isGetSockNameResponse_Result() +} + +type GetSockNameResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type GetSockNameResponse_Address struct { + Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3,oneof"` +} + +func (*GetSockNameResponse_ErrorNumber) isGetSockNameResponse_Result() {} + +func (*GetSockNameResponse_Address) isGetSockNameResponse_Result() {} + +func (m *GetSockNameResponse) GetResult() isGetSockNameResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *GetSockNameResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*GetSockNameResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *GetSockNameResponse) GetAddress() *AddressResponse { + if x, ok := m.GetResult().(*GetSockNameResponse_Address); ok { + return x.Address + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*GetSockNameResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*GetSockNameResponse_ErrorNumber)(nil), + (*GetSockNameResponse_Address)(nil), + } +} + +type GetPeerNameRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetPeerNameRequest) Reset() { *m = GetPeerNameRequest{} } +func (m *GetPeerNameRequest) String() string { return proto.CompactTextString(m) } +func (*GetPeerNameRequest) ProtoMessage() {} +func (*GetPeerNameRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{35} +} + +func (m *GetPeerNameRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetPeerNameRequest.Unmarshal(m, b) +} +func (m *GetPeerNameRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetPeerNameRequest.Marshal(b, m, deterministic) +} +func (m *GetPeerNameRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetPeerNameRequest.Merge(m, src) +} +func (m *GetPeerNameRequest) XXX_Size() int { + return xxx_messageInfo_GetPeerNameRequest.Size(m) +} +func (m *GetPeerNameRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetPeerNameRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetPeerNameRequest proto.InternalMessageInfo + +func (m *GetPeerNameRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +type GetPeerNameResponse struct { + // Types that are valid to be assigned to Result: + // *GetPeerNameResponse_ErrorNumber + // *GetPeerNameResponse_Address + Result isGetPeerNameResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetPeerNameResponse) Reset() { *m = GetPeerNameResponse{} } +func (m *GetPeerNameResponse) String() string { return proto.CompactTextString(m) } +func (*GetPeerNameResponse) ProtoMessage() {} +func (*GetPeerNameResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{36} +} + +func (m *GetPeerNameResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetPeerNameResponse.Unmarshal(m, b) +} +func (m *GetPeerNameResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetPeerNameResponse.Marshal(b, m, deterministic) +} +func (m *GetPeerNameResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetPeerNameResponse.Merge(m, src) +} +func (m *GetPeerNameResponse) XXX_Size() int { + return xxx_messageInfo_GetPeerNameResponse.Size(m) +} +func (m *GetPeerNameResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GetPeerNameResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GetPeerNameResponse proto.InternalMessageInfo + +type isGetPeerNameResponse_Result interface { + isGetPeerNameResponse_Result() +} + +type GetPeerNameResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type GetPeerNameResponse_Address struct { + Address *AddressResponse `protobuf:"bytes,2,opt,name=address,proto3,oneof"` +} + +func (*GetPeerNameResponse_ErrorNumber) isGetPeerNameResponse_Result() {} + +func (*GetPeerNameResponse_Address) isGetPeerNameResponse_Result() {} + +func (m *GetPeerNameResponse) GetResult() isGetPeerNameResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *GetPeerNameResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*GetPeerNameResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *GetPeerNameResponse) GetAddress() *AddressResponse { + if x, ok := m.GetResult().(*GetPeerNameResponse_Address); ok { + return x.Address + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*GetPeerNameResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*GetPeerNameResponse_ErrorNumber)(nil), + (*GetPeerNameResponse_Address)(nil), + } +} + +type SocketRequest struct { + Family int64 `protobuf:"varint,1,opt,name=family,proto3" json:"family,omitempty"` + Type int64 `protobuf:"varint,2,opt,name=type,proto3" json:"type,omitempty"` + Protocol int64 `protobuf:"varint,3,opt,name=protocol,proto3" json:"protocol,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SocketRequest) Reset() { *m = SocketRequest{} } +func (m *SocketRequest) String() string { return proto.CompactTextString(m) } +func (*SocketRequest) ProtoMessage() {} +func (*SocketRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{37} +} + +func (m *SocketRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SocketRequest.Unmarshal(m, b) +} +func (m *SocketRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SocketRequest.Marshal(b, m, deterministic) +} +func (m *SocketRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_SocketRequest.Merge(m, src) +} +func (m *SocketRequest) XXX_Size() int { + return xxx_messageInfo_SocketRequest.Size(m) +} +func (m *SocketRequest) XXX_DiscardUnknown() { + xxx_messageInfo_SocketRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_SocketRequest proto.InternalMessageInfo + +func (m *SocketRequest) GetFamily() int64 { + if m != nil { + return m.Family + } + return 0 +} + +func (m *SocketRequest) GetType() int64 { + if m != nil { + return m.Type + } + return 0 +} + +func (m *SocketRequest) GetProtocol() int64 { + if m != nil { + return m.Protocol + } + return 0 +} + +type SocketResponse struct { + // Types that are valid to be assigned to Result: + // *SocketResponse_ErrorNumber + // *SocketResponse_Fd + Result isSocketResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SocketResponse) Reset() { *m = SocketResponse{} } +func (m *SocketResponse) String() string { return proto.CompactTextString(m) } +func (*SocketResponse) ProtoMessage() {} +func (*SocketResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{38} +} + +func (m *SocketResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SocketResponse.Unmarshal(m, b) +} +func (m *SocketResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SocketResponse.Marshal(b, m, deterministic) +} +func (m *SocketResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_SocketResponse.Merge(m, src) +} +func (m *SocketResponse) XXX_Size() int { + return xxx_messageInfo_SocketResponse.Size(m) +} +func (m *SocketResponse) XXX_DiscardUnknown() { + xxx_messageInfo_SocketResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_SocketResponse proto.InternalMessageInfo + +type isSocketResponse_Result interface { + isSocketResponse_Result() +} + +type SocketResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type SocketResponse_Fd struct { + Fd uint32 `protobuf:"varint,2,opt,name=fd,proto3,oneof"` +} + +func (*SocketResponse_ErrorNumber) isSocketResponse_Result() {} + +func (*SocketResponse_Fd) isSocketResponse_Result() {} + +func (m *SocketResponse) GetResult() isSocketResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *SocketResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*SocketResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *SocketResponse) GetFd() uint32 { + if x, ok := m.GetResult().(*SocketResponse_Fd); ok { + return x.Fd + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*SocketResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*SocketResponse_ErrorNumber)(nil), + (*SocketResponse_Fd)(nil), + } +} + +type EpollWaitRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + NumEvents uint32 `protobuf:"varint,2,opt,name=num_events,json=numEvents,proto3" json:"num_events,omitempty"` + Msec int64 `protobuf:"zigzag64,3,opt,name=msec,proto3" json:"msec,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollWaitRequest) Reset() { *m = EpollWaitRequest{} } +func (m *EpollWaitRequest) String() string { return proto.CompactTextString(m) } +func (*EpollWaitRequest) ProtoMessage() {} +func (*EpollWaitRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{39} +} + +func (m *EpollWaitRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollWaitRequest.Unmarshal(m, b) +} +func (m *EpollWaitRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollWaitRequest.Marshal(b, m, deterministic) +} +func (m *EpollWaitRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollWaitRequest.Merge(m, src) +} +func (m *EpollWaitRequest) XXX_Size() int { + return xxx_messageInfo_EpollWaitRequest.Size(m) +} +func (m *EpollWaitRequest) XXX_DiscardUnknown() { + xxx_messageInfo_EpollWaitRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollWaitRequest proto.InternalMessageInfo + +func (m *EpollWaitRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *EpollWaitRequest) GetNumEvents() uint32 { + if m != nil { + return m.NumEvents + } + return 0 +} + +func (m *EpollWaitRequest) GetMsec() int64 { + if m != nil { + return m.Msec + } + return 0 +} + +type EpollEvent struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Events uint32 `protobuf:"varint,2,opt,name=events,proto3" json:"events,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollEvent) Reset() { *m = EpollEvent{} } +func (m *EpollEvent) String() string { return proto.CompactTextString(m) } +func (*EpollEvent) ProtoMessage() {} +func (*EpollEvent) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{40} +} + +func (m *EpollEvent) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollEvent.Unmarshal(m, b) +} +func (m *EpollEvent) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollEvent.Marshal(b, m, deterministic) +} +func (m *EpollEvent) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollEvent.Merge(m, src) +} +func (m *EpollEvent) XXX_Size() int { + return xxx_messageInfo_EpollEvent.Size(m) +} +func (m *EpollEvent) XXX_DiscardUnknown() { + xxx_messageInfo_EpollEvent.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollEvent proto.InternalMessageInfo + +func (m *EpollEvent) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *EpollEvent) GetEvents() uint32 { + if m != nil { + return m.Events + } + return 0 +} + +type EpollEvents struct { + Events []*EpollEvent `protobuf:"bytes,1,rep,name=events,proto3" json:"events,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollEvents) Reset() { *m = EpollEvents{} } +func (m *EpollEvents) String() string { return proto.CompactTextString(m) } +func (*EpollEvents) ProtoMessage() {} +func (*EpollEvents) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{41} +} + +func (m *EpollEvents) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollEvents.Unmarshal(m, b) +} +func (m *EpollEvents) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollEvents.Marshal(b, m, deterministic) +} +func (m *EpollEvents) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollEvents.Merge(m, src) +} +func (m *EpollEvents) XXX_Size() int { + return xxx_messageInfo_EpollEvents.Size(m) +} +func (m *EpollEvents) XXX_DiscardUnknown() { + xxx_messageInfo_EpollEvents.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollEvents proto.InternalMessageInfo + +func (m *EpollEvents) GetEvents() []*EpollEvent { + if m != nil { + return m.Events + } + return nil +} + +type EpollWaitResponse struct { + // Types that are valid to be assigned to Result: + // *EpollWaitResponse_ErrorNumber + // *EpollWaitResponse_Events + Result isEpollWaitResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollWaitResponse) Reset() { *m = EpollWaitResponse{} } +func (m *EpollWaitResponse) String() string { return proto.CompactTextString(m) } +func (*EpollWaitResponse) ProtoMessage() {} +func (*EpollWaitResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{42} +} + +func (m *EpollWaitResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollWaitResponse.Unmarshal(m, b) +} +func (m *EpollWaitResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollWaitResponse.Marshal(b, m, deterministic) +} +func (m *EpollWaitResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollWaitResponse.Merge(m, src) +} +func (m *EpollWaitResponse) XXX_Size() int { + return xxx_messageInfo_EpollWaitResponse.Size(m) +} +func (m *EpollWaitResponse) XXX_DiscardUnknown() { + xxx_messageInfo_EpollWaitResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollWaitResponse proto.InternalMessageInfo + +type isEpollWaitResponse_Result interface { + isEpollWaitResponse_Result() +} + +type EpollWaitResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type EpollWaitResponse_Events struct { + Events *EpollEvents `protobuf:"bytes,2,opt,name=events,proto3,oneof"` +} + +func (*EpollWaitResponse_ErrorNumber) isEpollWaitResponse_Result() {} + +func (*EpollWaitResponse_Events) isEpollWaitResponse_Result() {} + +func (m *EpollWaitResponse) GetResult() isEpollWaitResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *EpollWaitResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*EpollWaitResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *EpollWaitResponse) GetEvents() *EpollEvents { + if x, ok := m.GetResult().(*EpollWaitResponse_Events); ok { + return x.Events + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*EpollWaitResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*EpollWaitResponse_ErrorNumber)(nil), + (*EpollWaitResponse_Events)(nil), + } +} + +type EpollCtlRequest struct { + Epfd uint32 `protobuf:"varint,1,opt,name=epfd,proto3" json:"epfd,omitempty"` + Op int64 `protobuf:"varint,2,opt,name=op,proto3" json:"op,omitempty"` + Fd uint32 `protobuf:"varint,3,opt,name=fd,proto3" json:"fd,omitempty"` + Event *EpollEvent `protobuf:"bytes,4,opt,name=event,proto3" json:"event,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollCtlRequest) Reset() { *m = EpollCtlRequest{} } +func (m *EpollCtlRequest) String() string { return proto.CompactTextString(m) } +func (*EpollCtlRequest) ProtoMessage() {} +func (*EpollCtlRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{43} +} + +func (m *EpollCtlRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollCtlRequest.Unmarshal(m, b) +} +func (m *EpollCtlRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollCtlRequest.Marshal(b, m, deterministic) +} +func (m *EpollCtlRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollCtlRequest.Merge(m, src) +} +func (m *EpollCtlRequest) XXX_Size() int { + return xxx_messageInfo_EpollCtlRequest.Size(m) +} +func (m *EpollCtlRequest) XXX_DiscardUnknown() { + xxx_messageInfo_EpollCtlRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollCtlRequest proto.InternalMessageInfo + +func (m *EpollCtlRequest) GetEpfd() uint32 { + if m != nil { + return m.Epfd + } + return 0 +} + +func (m *EpollCtlRequest) GetOp() int64 { + if m != nil { + return m.Op + } + return 0 +} + +func (m *EpollCtlRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *EpollCtlRequest) GetEvent() *EpollEvent { + if m != nil { + return m.Event + } + return nil +} + +type EpollCtlResponse struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3" json:"error_number,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollCtlResponse) Reset() { *m = EpollCtlResponse{} } +func (m *EpollCtlResponse) String() string { return proto.CompactTextString(m) } +func (*EpollCtlResponse) ProtoMessage() {} +func (*EpollCtlResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{44} +} + +func (m *EpollCtlResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollCtlResponse.Unmarshal(m, b) +} +func (m *EpollCtlResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollCtlResponse.Marshal(b, m, deterministic) +} +func (m *EpollCtlResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollCtlResponse.Merge(m, src) +} +func (m *EpollCtlResponse) XXX_Size() int { + return xxx_messageInfo_EpollCtlResponse.Size(m) +} +func (m *EpollCtlResponse) XXX_DiscardUnknown() { + xxx_messageInfo_EpollCtlResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollCtlResponse proto.InternalMessageInfo + +func (m *EpollCtlResponse) GetErrorNumber() uint32 { + if m != nil { + return m.ErrorNumber + } + return 0 +} + +type EpollCreate1Request struct { + Flag int64 `protobuf:"varint,1,opt,name=flag,proto3" json:"flag,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollCreate1Request) Reset() { *m = EpollCreate1Request{} } +func (m *EpollCreate1Request) String() string { return proto.CompactTextString(m) } +func (*EpollCreate1Request) ProtoMessage() {} +func (*EpollCreate1Request) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{45} +} + +func (m *EpollCreate1Request) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollCreate1Request.Unmarshal(m, b) +} +func (m *EpollCreate1Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollCreate1Request.Marshal(b, m, deterministic) +} +func (m *EpollCreate1Request) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollCreate1Request.Merge(m, src) +} +func (m *EpollCreate1Request) XXX_Size() int { + return xxx_messageInfo_EpollCreate1Request.Size(m) +} +func (m *EpollCreate1Request) XXX_DiscardUnknown() { + xxx_messageInfo_EpollCreate1Request.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollCreate1Request proto.InternalMessageInfo + +func (m *EpollCreate1Request) GetFlag() int64 { + if m != nil { + return m.Flag + } + return 0 +} + +type EpollCreate1Response struct { + // Types that are valid to be assigned to Result: + // *EpollCreate1Response_ErrorNumber + // *EpollCreate1Response_Fd + Result isEpollCreate1Response_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *EpollCreate1Response) Reset() { *m = EpollCreate1Response{} } +func (m *EpollCreate1Response) String() string { return proto.CompactTextString(m) } +func (*EpollCreate1Response) ProtoMessage() {} +func (*EpollCreate1Response) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{46} +} + +func (m *EpollCreate1Response) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_EpollCreate1Response.Unmarshal(m, b) +} +func (m *EpollCreate1Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_EpollCreate1Response.Marshal(b, m, deterministic) +} +func (m *EpollCreate1Response) XXX_Merge(src proto.Message) { + xxx_messageInfo_EpollCreate1Response.Merge(m, src) +} +func (m *EpollCreate1Response) XXX_Size() int { + return xxx_messageInfo_EpollCreate1Response.Size(m) +} +func (m *EpollCreate1Response) XXX_DiscardUnknown() { + xxx_messageInfo_EpollCreate1Response.DiscardUnknown(m) +} + +var xxx_messageInfo_EpollCreate1Response proto.InternalMessageInfo + +type isEpollCreate1Response_Result interface { + isEpollCreate1Response_Result() +} + +type EpollCreate1Response_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type EpollCreate1Response_Fd struct { + Fd uint32 `protobuf:"varint,2,opt,name=fd,proto3,oneof"` +} + +func (*EpollCreate1Response_ErrorNumber) isEpollCreate1Response_Result() {} + +func (*EpollCreate1Response_Fd) isEpollCreate1Response_Result() {} + +func (m *EpollCreate1Response) GetResult() isEpollCreate1Response_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *EpollCreate1Response) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*EpollCreate1Response_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *EpollCreate1Response) GetFd() uint32 { + if x, ok := m.GetResult().(*EpollCreate1Response_Fd); ok { + return x.Fd + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*EpollCreate1Response) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*EpollCreate1Response_ErrorNumber)(nil), + (*EpollCreate1Response_Fd)(nil), + } +} + +type PollRequest struct { + Fd uint32 `protobuf:"varint,1,opt,name=fd,proto3" json:"fd,omitempty"` + Events uint32 `protobuf:"varint,2,opt,name=events,proto3" json:"events,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *PollRequest) Reset() { *m = PollRequest{} } +func (m *PollRequest) String() string { return proto.CompactTextString(m) } +func (*PollRequest) ProtoMessage() {} +func (*PollRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{47} +} + +func (m *PollRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_PollRequest.Unmarshal(m, b) +} +func (m *PollRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_PollRequest.Marshal(b, m, deterministic) +} +func (m *PollRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_PollRequest.Merge(m, src) +} +func (m *PollRequest) XXX_Size() int { + return xxx_messageInfo_PollRequest.Size(m) +} +func (m *PollRequest) XXX_DiscardUnknown() { + xxx_messageInfo_PollRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_PollRequest proto.InternalMessageInfo + +func (m *PollRequest) GetFd() uint32 { + if m != nil { + return m.Fd + } + return 0 +} + +func (m *PollRequest) GetEvents() uint32 { + if m != nil { + return m.Events + } + return 0 +} + +type PollResponse struct { + // Types that are valid to be assigned to Result: + // *PollResponse_ErrorNumber + // *PollResponse_Events + Result isPollResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *PollResponse) Reset() { *m = PollResponse{} } +func (m *PollResponse) String() string { return proto.CompactTextString(m) } +func (*PollResponse) ProtoMessage() {} +func (*PollResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{48} +} + +func (m *PollResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_PollResponse.Unmarshal(m, b) +} +func (m *PollResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_PollResponse.Marshal(b, m, deterministic) +} +func (m *PollResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_PollResponse.Merge(m, src) +} +func (m *PollResponse) XXX_Size() int { + return xxx_messageInfo_PollResponse.Size(m) +} +func (m *PollResponse) XXX_DiscardUnknown() { + xxx_messageInfo_PollResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_PollResponse proto.InternalMessageInfo + +type isPollResponse_Result interface { + isPollResponse_Result() +} + +type PollResponse_ErrorNumber struct { + ErrorNumber uint32 `protobuf:"varint,1,opt,name=error_number,json=errorNumber,proto3,oneof"` +} + +type PollResponse_Events struct { + Events uint32 `protobuf:"varint,2,opt,name=events,proto3,oneof"` +} + +func (*PollResponse_ErrorNumber) isPollResponse_Result() {} + +func (*PollResponse_Events) isPollResponse_Result() {} + +func (m *PollResponse) GetResult() isPollResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *PollResponse) GetErrorNumber() uint32 { + if x, ok := m.GetResult().(*PollResponse_ErrorNumber); ok { + return x.ErrorNumber + } + return 0 +} + +func (m *PollResponse) GetEvents() uint32 { + if x, ok := m.GetResult().(*PollResponse_Events); ok { + return x.Events + } + return 0 +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*PollResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*PollResponse_ErrorNumber)(nil), + (*PollResponse_Events)(nil), + } +} + +type SyscallRequest struct { + // Types that are valid to be assigned to Args: + // *SyscallRequest_Socket + // *SyscallRequest_Sendmsg + // *SyscallRequest_Recvmsg + // *SyscallRequest_Bind + // *SyscallRequest_Accept + // *SyscallRequest_Connect + // *SyscallRequest_Listen + // *SyscallRequest_Shutdown + // *SyscallRequest_Close + // *SyscallRequest_GetSockOpt + // *SyscallRequest_SetSockOpt + // *SyscallRequest_GetSockName + // *SyscallRequest_GetPeerName + // *SyscallRequest_EpollWait + // *SyscallRequest_EpollCtl + // *SyscallRequest_EpollCreate1 + // *SyscallRequest_Poll + // *SyscallRequest_Read + // *SyscallRequest_Write + // *SyscallRequest_Open + // *SyscallRequest_Ioctl + // *SyscallRequest_WriteFile + // *SyscallRequest_ReadFile + Args isSyscallRequest_Args `protobuf_oneof:"args"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SyscallRequest) Reset() { *m = SyscallRequest{} } +func (m *SyscallRequest) String() string { return proto.CompactTextString(m) } +func (*SyscallRequest) ProtoMessage() {} +func (*SyscallRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{49} +} + +func (m *SyscallRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SyscallRequest.Unmarshal(m, b) +} +func (m *SyscallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SyscallRequest.Marshal(b, m, deterministic) +} +func (m *SyscallRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_SyscallRequest.Merge(m, src) +} +func (m *SyscallRequest) XXX_Size() int { + return xxx_messageInfo_SyscallRequest.Size(m) +} +func (m *SyscallRequest) XXX_DiscardUnknown() { + xxx_messageInfo_SyscallRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_SyscallRequest proto.InternalMessageInfo + +type isSyscallRequest_Args interface { + isSyscallRequest_Args() +} + +type SyscallRequest_Socket struct { + Socket *SocketRequest `protobuf:"bytes,1,opt,name=socket,proto3,oneof"` +} + +type SyscallRequest_Sendmsg struct { + Sendmsg *SendmsgRequest `protobuf:"bytes,2,opt,name=sendmsg,proto3,oneof"` +} + +type SyscallRequest_Recvmsg struct { + Recvmsg *RecvmsgRequest `protobuf:"bytes,3,opt,name=recvmsg,proto3,oneof"` +} + +type SyscallRequest_Bind struct { + Bind *BindRequest `protobuf:"bytes,4,opt,name=bind,proto3,oneof"` +} + +type SyscallRequest_Accept struct { + Accept *AcceptRequest `protobuf:"bytes,5,opt,name=accept,proto3,oneof"` +} + +type SyscallRequest_Connect struct { + Connect *ConnectRequest `protobuf:"bytes,6,opt,name=connect,proto3,oneof"` +} + +type SyscallRequest_Listen struct { + Listen *ListenRequest `protobuf:"bytes,7,opt,name=listen,proto3,oneof"` +} + +type SyscallRequest_Shutdown struct { + Shutdown *ShutdownRequest `protobuf:"bytes,8,opt,name=shutdown,proto3,oneof"` +} + +type SyscallRequest_Close struct { + Close *CloseRequest `protobuf:"bytes,9,opt,name=close,proto3,oneof"` +} + +type SyscallRequest_GetSockOpt struct { + GetSockOpt *GetSockOptRequest `protobuf:"bytes,10,opt,name=get_sock_opt,json=getSockOpt,proto3,oneof"` +} + +type SyscallRequest_SetSockOpt struct { + SetSockOpt *SetSockOptRequest `protobuf:"bytes,11,opt,name=set_sock_opt,json=setSockOpt,proto3,oneof"` +} + +type SyscallRequest_GetSockName struct { + GetSockName *GetSockNameRequest `protobuf:"bytes,12,opt,name=get_sock_name,json=getSockName,proto3,oneof"` +} + +type SyscallRequest_GetPeerName struct { + GetPeerName *GetPeerNameRequest `protobuf:"bytes,13,opt,name=get_peer_name,json=getPeerName,proto3,oneof"` +} + +type SyscallRequest_EpollWait struct { + EpollWait *EpollWaitRequest `protobuf:"bytes,14,opt,name=epoll_wait,json=epollWait,proto3,oneof"` +} + +type SyscallRequest_EpollCtl struct { + EpollCtl *EpollCtlRequest `protobuf:"bytes,15,opt,name=epoll_ctl,json=epollCtl,proto3,oneof"` +} + +type SyscallRequest_EpollCreate1 struct { + EpollCreate1 *EpollCreate1Request `protobuf:"bytes,16,opt,name=epoll_create1,json=epollCreate1,proto3,oneof"` +} + +type SyscallRequest_Poll struct { + Poll *PollRequest `protobuf:"bytes,17,opt,name=poll,proto3,oneof"` +} + +type SyscallRequest_Read struct { + Read *ReadRequest `protobuf:"bytes,18,opt,name=read,proto3,oneof"` +} + +type SyscallRequest_Write struct { + Write *WriteRequest `protobuf:"bytes,19,opt,name=write,proto3,oneof"` +} + +type SyscallRequest_Open struct { + Open *OpenRequest `protobuf:"bytes,20,opt,name=open,proto3,oneof"` +} + +type SyscallRequest_Ioctl struct { + Ioctl *IOCtlRequest `protobuf:"bytes,21,opt,name=ioctl,proto3,oneof"` +} + +type SyscallRequest_WriteFile struct { + WriteFile *WriteFileRequest `protobuf:"bytes,22,opt,name=write_file,json=writeFile,proto3,oneof"` +} + +type SyscallRequest_ReadFile struct { + ReadFile *ReadFileRequest `protobuf:"bytes,23,opt,name=read_file,json=readFile,proto3,oneof"` +} + +func (*SyscallRequest_Socket) isSyscallRequest_Args() {} + +func (*SyscallRequest_Sendmsg) isSyscallRequest_Args() {} + +func (*SyscallRequest_Recvmsg) isSyscallRequest_Args() {} + +func (*SyscallRequest_Bind) isSyscallRequest_Args() {} + +func (*SyscallRequest_Accept) isSyscallRequest_Args() {} + +func (*SyscallRequest_Connect) isSyscallRequest_Args() {} + +func (*SyscallRequest_Listen) isSyscallRequest_Args() {} + +func (*SyscallRequest_Shutdown) isSyscallRequest_Args() {} + +func (*SyscallRequest_Close) isSyscallRequest_Args() {} + +func (*SyscallRequest_GetSockOpt) isSyscallRequest_Args() {} + +func (*SyscallRequest_SetSockOpt) isSyscallRequest_Args() {} + +func (*SyscallRequest_GetSockName) isSyscallRequest_Args() {} + +func (*SyscallRequest_GetPeerName) isSyscallRequest_Args() {} + +func (*SyscallRequest_EpollWait) isSyscallRequest_Args() {} + +func (*SyscallRequest_EpollCtl) isSyscallRequest_Args() {} + +func (*SyscallRequest_EpollCreate1) isSyscallRequest_Args() {} + +func (*SyscallRequest_Poll) isSyscallRequest_Args() {} + +func (*SyscallRequest_Read) isSyscallRequest_Args() {} + +func (*SyscallRequest_Write) isSyscallRequest_Args() {} + +func (*SyscallRequest_Open) isSyscallRequest_Args() {} + +func (*SyscallRequest_Ioctl) isSyscallRequest_Args() {} + +func (*SyscallRequest_WriteFile) isSyscallRequest_Args() {} + +func (*SyscallRequest_ReadFile) isSyscallRequest_Args() {} + +func (m *SyscallRequest) GetArgs() isSyscallRequest_Args { + if m != nil { + return m.Args + } + return nil +} + +func (m *SyscallRequest) GetSocket() *SocketRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Socket); ok { + return x.Socket + } + return nil +} + +func (m *SyscallRequest) GetSendmsg() *SendmsgRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Sendmsg); ok { + return x.Sendmsg + } + return nil +} + +func (m *SyscallRequest) GetRecvmsg() *RecvmsgRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Recvmsg); ok { + return x.Recvmsg + } + return nil +} + +func (m *SyscallRequest) GetBind() *BindRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Bind); ok { + return x.Bind + } + return nil +} + +func (m *SyscallRequest) GetAccept() *AcceptRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Accept); ok { + return x.Accept + } + return nil +} + +func (m *SyscallRequest) GetConnect() *ConnectRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Connect); ok { + return x.Connect + } + return nil +} + +func (m *SyscallRequest) GetListen() *ListenRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Listen); ok { + return x.Listen + } + return nil +} + +func (m *SyscallRequest) GetShutdown() *ShutdownRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Shutdown); ok { + return x.Shutdown + } + return nil +} + +func (m *SyscallRequest) GetClose() *CloseRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Close); ok { + return x.Close + } + return nil +} + +func (m *SyscallRequest) GetGetSockOpt() *GetSockOptRequest { + if x, ok := m.GetArgs().(*SyscallRequest_GetSockOpt); ok { + return x.GetSockOpt + } + return nil +} + +func (m *SyscallRequest) GetSetSockOpt() *SetSockOptRequest { + if x, ok := m.GetArgs().(*SyscallRequest_SetSockOpt); ok { + return x.SetSockOpt + } + return nil +} + +func (m *SyscallRequest) GetGetSockName() *GetSockNameRequest { + if x, ok := m.GetArgs().(*SyscallRequest_GetSockName); ok { + return x.GetSockName + } + return nil +} + +func (m *SyscallRequest) GetGetPeerName() *GetPeerNameRequest { + if x, ok := m.GetArgs().(*SyscallRequest_GetPeerName); ok { + return x.GetPeerName + } + return nil +} + +func (m *SyscallRequest) GetEpollWait() *EpollWaitRequest { + if x, ok := m.GetArgs().(*SyscallRequest_EpollWait); ok { + return x.EpollWait + } + return nil +} + +func (m *SyscallRequest) GetEpollCtl() *EpollCtlRequest { + if x, ok := m.GetArgs().(*SyscallRequest_EpollCtl); ok { + return x.EpollCtl + } + return nil +} + +func (m *SyscallRequest) GetEpollCreate1() *EpollCreate1Request { + if x, ok := m.GetArgs().(*SyscallRequest_EpollCreate1); ok { + return x.EpollCreate1 + } + return nil +} + +func (m *SyscallRequest) GetPoll() *PollRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Poll); ok { + return x.Poll + } + return nil +} + +func (m *SyscallRequest) GetRead() *ReadRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Read); ok { + return x.Read + } + return nil +} + +func (m *SyscallRequest) GetWrite() *WriteRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Write); ok { + return x.Write + } + return nil +} + +func (m *SyscallRequest) GetOpen() *OpenRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Open); ok { + return x.Open + } + return nil +} + +func (m *SyscallRequest) GetIoctl() *IOCtlRequest { + if x, ok := m.GetArgs().(*SyscallRequest_Ioctl); ok { + return x.Ioctl + } + return nil +} + +func (m *SyscallRequest) GetWriteFile() *WriteFileRequest { + if x, ok := m.GetArgs().(*SyscallRequest_WriteFile); ok { + return x.WriteFile + } + return nil +} + +func (m *SyscallRequest) GetReadFile() *ReadFileRequest { + if x, ok := m.GetArgs().(*SyscallRequest_ReadFile); ok { + return x.ReadFile + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*SyscallRequest) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*SyscallRequest_Socket)(nil), + (*SyscallRequest_Sendmsg)(nil), + (*SyscallRequest_Recvmsg)(nil), + (*SyscallRequest_Bind)(nil), + (*SyscallRequest_Accept)(nil), + (*SyscallRequest_Connect)(nil), + (*SyscallRequest_Listen)(nil), + (*SyscallRequest_Shutdown)(nil), + (*SyscallRequest_Close)(nil), + (*SyscallRequest_GetSockOpt)(nil), + (*SyscallRequest_SetSockOpt)(nil), + (*SyscallRequest_GetSockName)(nil), + (*SyscallRequest_GetPeerName)(nil), + (*SyscallRequest_EpollWait)(nil), + (*SyscallRequest_EpollCtl)(nil), + (*SyscallRequest_EpollCreate1)(nil), + (*SyscallRequest_Poll)(nil), + (*SyscallRequest_Read)(nil), + (*SyscallRequest_Write)(nil), + (*SyscallRequest_Open)(nil), + (*SyscallRequest_Ioctl)(nil), + (*SyscallRequest_WriteFile)(nil), + (*SyscallRequest_ReadFile)(nil), + } +} + +type SyscallResponse struct { + // Types that are valid to be assigned to Result: + // *SyscallResponse_Socket + // *SyscallResponse_Sendmsg + // *SyscallResponse_Recvmsg + // *SyscallResponse_Bind + // *SyscallResponse_Accept + // *SyscallResponse_Connect + // *SyscallResponse_Listen + // *SyscallResponse_Shutdown + // *SyscallResponse_Close + // *SyscallResponse_GetSockOpt + // *SyscallResponse_SetSockOpt + // *SyscallResponse_GetSockName + // *SyscallResponse_GetPeerName + // *SyscallResponse_EpollWait + // *SyscallResponse_EpollCtl + // *SyscallResponse_EpollCreate1 + // *SyscallResponse_Poll + // *SyscallResponse_Read + // *SyscallResponse_Write + // *SyscallResponse_Open + // *SyscallResponse_Ioctl + // *SyscallResponse_WriteFile + // *SyscallResponse_ReadFile + Result isSyscallResponse_Result `protobuf_oneof:"result"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *SyscallResponse) Reset() { *m = SyscallResponse{} } +func (m *SyscallResponse) String() string { return proto.CompactTextString(m) } +func (*SyscallResponse) ProtoMessage() {} +func (*SyscallResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_dd04f3a8f0c5288b, []int{50} +} + +func (m *SyscallResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_SyscallResponse.Unmarshal(m, b) +} +func (m *SyscallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_SyscallResponse.Marshal(b, m, deterministic) +} +func (m *SyscallResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_SyscallResponse.Merge(m, src) +} +func (m *SyscallResponse) XXX_Size() int { + return xxx_messageInfo_SyscallResponse.Size(m) +} +func (m *SyscallResponse) XXX_DiscardUnknown() { + xxx_messageInfo_SyscallResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_SyscallResponse proto.InternalMessageInfo + +type isSyscallResponse_Result interface { + isSyscallResponse_Result() +} + +type SyscallResponse_Socket struct { + Socket *SocketResponse `protobuf:"bytes,1,opt,name=socket,proto3,oneof"` +} + +type SyscallResponse_Sendmsg struct { + Sendmsg *SendmsgResponse `protobuf:"bytes,2,opt,name=sendmsg,proto3,oneof"` +} + +type SyscallResponse_Recvmsg struct { + Recvmsg *RecvmsgResponse `protobuf:"bytes,3,opt,name=recvmsg,proto3,oneof"` +} + +type SyscallResponse_Bind struct { + Bind *BindResponse `protobuf:"bytes,4,opt,name=bind,proto3,oneof"` +} + +type SyscallResponse_Accept struct { + Accept *AcceptResponse `protobuf:"bytes,5,opt,name=accept,proto3,oneof"` +} + +type SyscallResponse_Connect struct { + Connect *ConnectResponse `protobuf:"bytes,6,opt,name=connect,proto3,oneof"` +} + +type SyscallResponse_Listen struct { + Listen *ListenResponse `protobuf:"bytes,7,opt,name=listen,proto3,oneof"` +} + +type SyscallResponse_Shutdown struct { + Shutdown *ShutdownResponse `protobuf:"bytes,8,opt,name=shutdown,proto3,oneof"` +} + +type SyscallResponse_Close struct { + Close *CloseResponse `protobuf:"bytes,9,opt,name=close,proto3,oneof"` +} + +type SyscallResponse_GetSockOpt struct { + GetSockOpt *GetSockOptResponse `protobuf:"bytes,10,opt,name=get_sock_opt,json=getSockOpt,proto3,oneof"` +} + +type SyscallResponse_SetSockOpt struct { + SetSockOpt *SetSockOptResponse `protobuf:"bytes,11,opt,name=set_sock_opt,json=setSockOpt,proto3,oneof"` +} + +type SyscallResponse_GetSockName struct { + GetSockName *GetSockNameResponse `protobuf:"bytes,12,opt,name=get_sock_name,json=getSockName,proto3,oneof"` +} + +type SyscallResponse_GetPeerName struct { + GetPeerName *GetPeerNameResponse `protobuf:"bytes,13,opt,name=get_peer_name,json=getPeerName,proto3,oneof"` +} + +type SyscallResponse_EpollWait struct { + EpollWait *EpollWaitResponse `protobuf:"bytes,14,opt,name=epoll_wait,json=epollWait,proto3,oneof"` +} + +type SyscallResponse_EpollCtl struct { + EpollCtl *EpollCtlResponse `protobuf:"bytes,15,opt,name=epoll_ctl,json=epollCtl,proto3,oneof"` +} + +type SyscallResponse_EpollCreate1 struct { + EpollCreate1 *EpollCreate1Response `protobuf:"bytes,16,opt,name=epoll_create1,json=epollCreate1,proto3,oneof"` +} + +type SyscallResponse_Poll struct { + Poll *PollResponse `protobuf:"bytes,17,opt,name=poll,proto3,oneof"` +} + +type SyscallResponse_Read struct { + Read *ReadResponse `protobuf:"bytes,18,opt,name=read,proto3,oneof"` +} + +type SyscallResponse_Write struct { + Write *WriteResponse `protobuf:"bytes,19,opt,name=write,proto3,oneof"` +} + +type SyscallResponse_Open struct { + Open *OpenResponse `protobuf:"bytes,20,opt,name=open,proto3,oneof"` +} + +type SyscallResponse_Ioctl struct { + Ioctl *IOCtlResponse `protobuf:"bytes,21,opt,name=ioctl,proto3,oneof"` +} + +type SyscallResponse_WriteFile struct { + WriteFile *WriteFileResponse `protobuf:"bytes,22,opt,name=write_file,json=writeFile,proto3,oneof"` +} + +type SyscallResponse_ReadFile struct { + ReadFile *ReadFileResponse `protobuf:"bytes,23,opt,name=read_file,json=readFile,proto3,oneof"` +} + +func (*SyscallResponse_Socket) isSyscallResponse_Result() {} + +func (*SyscallResponse_Sendmsg) isSyscallResponse_Result() {} + +func (*SyscallResponse_Recvmsg) isSyscallResponse_Result() {} + +func (*SyscallResponse_Bind) isSyscallResponse_Result() {} + +func (*SyscallResponse_Accept) isSyscallResponse_Result() {} + +func (*SyscallResponse_Connect) isSyscallResponse_Result() {} + +func (*SyscallResponse_Listen) isSyscallResponse_Result() {} + +func (*SyscallResponse_Shutdown) isSyscallResponse_Result() {} + +func (*SyscallResponse_Close) isSyscallResponse_Result() {} + +func (*SyscallResponse_GetSockOpt) isSyscallResponse_Result() {} + +func (*SyscallResponse_SetSockOpt) isSyscallResponse_Result() {} + +func (*SyscallResponse_GetSockName) isSyscallResponse_Result() {} + +func (*SyscallResponse_GetPeerName) isSyscallResponse_Result() {} + +func (*SyscallResponse_EpollWait) isSyscallResponse_Result() {} + +func (*SyscallResponse_EpollCtl) isSyscallResponse_Result() {} + +func (*SyscallResponse_EpollCreate1) isSyscallResponse_Result() {} + +func (*SyscallResponse_Poll) isSyscallResponse_Result() {} + +func (*SyscallResponse_Read) isSyscallResponse_Result() {} + +func (*SyscallResponse_Write) isSyscallResponse_Result() {} + +func (*SyscallResponse_Open) isSyscallResponse_Result() {} + +func (*SyscallResponse_Ioctl) isSyscallResponse_Result() {} + +func (*SyscallResponse_WriteFile) isSyscallResponse_Result() {} + +func (*SyscallResponse_ReadFile) isSyscallResponse_Result() {} + +func (m *SyscallResponse) GetResult() isSyscallResponse_Result { + if m != nil { + return m.Result + } + return nil +} + +func (m *SyscallResponse) GetSocket() *SocketResponse { + if x, ok := m.GetResult().(*SyscallResponse_Socket); ok { + return x.Socket + } + return nil +} + +func (m *SyscallResponse) GetSendmsg() *SendmsgResponse { + if x, ok := m.GetResult().(*SyscallResponse_Sendmsg); ok { + return x.Sendmsg + } + return nil +} + +func (m *SyscallResponse) GetRecvmsg() *RecvmsgResponse { + if x, ok := m.GetResult().(*SyscallResponse_Recvmsg); ok { + return x.Recvmsg + } + return nil +} + +func (m *SyscallResponse) GetBind() *BindResponse { + if x, ok := m.GetResult().(*SyscallResponse_Bind); ok { + return x.Bind + } + return nil +} + +func (m *SyscallResponse) GetAccept() *AcceptResponse { + if x, ok := m.GetResult().(*SyscallResponse_Accept); ok { + return x.Accept + } + return nil +} + +func (m *SyscallResponse) GetConnect() *ConnectResponse { + if x, ok := m.GetResult().(*SyscallResponse_Connect); ok { + return x.Connect + } + return nil +} + +func (m *SyscallResponse) GetListen() *ListenResponse { + if x, ok := m.GetResult().(*SyscallResponse_Listen); ok { + return x.Listen + } + return nil +} + +func (m *SyscallResponse) GetShutdown() *ShutdownResponse { + if x, ok := m.GetResult().(*SyscallResponse_Shutdown); ok { + return x.Shutdown + } + return nil +} + +func (m *SyscallResponse) GetClose() *CloseResponse { + if x, ok := m.GetResult().(*SyscallResponse_Close); ok { + return x.Close + } + return nil +} + +func (m *SyscallResponse) GetGetSockOpt() *GetSockOptResponse { + if x, ok := m.GetResult().(*SyscallResponse_GetSockOpt); ok { + return x.GetSockOpt + } + return nil +} + +func (m *SyscallResponse) GetSetSockOpt() *SetSockOptResponse { + if x, ok := m.GetResult().(*SyscallResponse_SetSockOpt); ok { + return x.SetSockOpt + } + return nil +} + +func (m *SyscallResponse) GetGetSockName() *GetSockNameResponse { + if x, ok := m.GetResult().(*SyscallResponse_GetSockName); ok { + return x.GetSockName + } + return nil +} + +func (m *SyscallResponse) GetGetPeerName() *GetPeerNameResponse { + if x, ok := m.GetResult().(*SyscallResponse_GetPeerName); ok { + return x.GetPeerName + } + return nil +} + +func (m *SyscallResponse) GetEpollWait() *EpollWaitResponse { + if x, ok := m.GetResult().(*SyscallResponse_EpollWait); ok { + return x.EpollWait + } + return nil +} + +func (m *SyscallResponse) GetEpollCtl() *EpollCtlResponse { + if x, ok := m.GetResult().(*SyscallResponse_EpollCtl); ok { + return x.EpollCtl + } + return nil +} + +func (m *SyscallResponse) GetEpollCreate1() *EpollCreate1Response { + if x, ok := m.GetResult().(*SyscallResponse_EpollCreate1); ok { + return x.EpollCreate1 + } + return nil +} + +func (m *SyscallResponse) GetPoll() *PollResponse { + if x, ok := m.GetResult().(*SyscallResponse_Poll); ok { + return x.Poll + } + return nil +} + +func (m *SyscallResponse) GetRead() *ReadResponse { + if x, ok := m.GetResult().(*SyscallResponse_Read); ok { + return x.Read + } + return nil +} + +func (m *SyscallResponse) GetWrite() *WriteResponse { + if x, ok := m.GetResult().(*SyscallResponse_Write); ok { + return x.Write + } + return nil +} + +func (m *SyscallResponse) GetOpen() *OpenResponse { + if x, ok := m.GetResult().(*SyscallResponse_Open); ok { + return x.Open + } + return nil +} + +func (m *SyscallResponse) GetIoctl() *IOCtlResponse { + if x, ok := m.GetResult().(*SyscallResponse_Ioctl); ok { + return x.Ioctl + } + return nil +} + +func (m *SyscallResponse) GetWriteFile() *WriteFileResponse { + if x, ok := m.GetResult().(*SyscallResponse_WriteFile); ok { + return x.WriteFile + } + return nil +} + +func (m *SyscallResponse) GetReadFile() *ReadFileResponse { + if x, ok := m.GetResult().(*SyscallResponse_ReadFile); ok { + return x.ReadFile + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*SyscallResponse) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*SyscallResponse_Socket)(nil), + (*SyscallResponse_Sendmsg)(nil), + (*SyscallResponse_Recvmsg)(nil), + (*SyscallResponse_Bind)(nil), + (*SyscallResponse_Accept)(nil), + (*SyscallResponse_Connect)(nil), + (*SyscallResponse_Listen)(nil), + (*SyscallResponse_Shutdown)(nil), + (*SyscallResponse_Close)(nil), + (*SyscallResponse_GetSockOpt)(nil), + (*SyscallResponse_SetSockOpt)(nil), + (*SyscallResponse_GetSockName)(nil), + (*SyscallResponse_GetPeerName)(nil), + (*SyscallResponse_EpollWait)(nil), + (*SyscallResponse_EpollCtl)(nil), + (*SyscallResponse_EpollCreate1)(nil), + (*SyscallResponse_Poll)(nil), + (*SyscallResponse_Read)(nil), + (*SyscallResponse_Write)(nil), + (*SyscallResponse_Open)(nil), + (*SyscallResponse_Ioctl)(nil), + (*SyscallResponse_WriteFile)(nil), + (*SyscallResponse_ReadFile)(nil), + } +} + +func init() { + proto.RegisterType((*SendmsgRequest)(nil), "syscall_rpc.SendmsgRequest") + proto.RegisterType((*SendmsgResponse)(nil), "syscall_rpc.SendmsgResponse") + proto.RegisterType((*IOCtlRequest)(nil), "syscall_rpc.IOCtlRequest") + proto.RegisterType((*IOCtlResponse)(nil), "syscall_rpc.IOCtlResponse") + proto.RegisterType((*RecvmsgRequest)(nil), "syscall_rpc.RecvmsgRequest") + proto.RegisterType((*OpenRequest)(nil), "syscall_rpc.OpenRequest") + proto.RegisterType((*OpenResponse)(nil), "syscall_rpc.OpenResponse") + proto.RegisterType((*ReadRequest)(nil), "syscall_rpc.ReadRequest") + proto.RegisterType((*ReadResponse)(nil), "syscall_rpc.ReadResponse") + proto.RegisterType((*ReadFileRequest)(nil), "syscall_rpc.ReadFileRequest") + proto.RegisterType((*ReadFileResponse)(nil), "syscall_rpc.ReadFileResponse") + proto.RegisterType((*WriteRequest)(nil), "syscall_rpc.WriteRequest") + proto.RegisterType((*WriteResponse)(nil), "syscall_rpc.WriteResponse") + proto.RegisterType((*WriteFileRequest)(nil), "syscall_rpc.WriteFileRequest") + proto.RegisterType((*WriteFileResponse)(nil), "syscall_rpc.WriteFileResponse") + proto.RegisterType((*AddressResponse)(nil), "syscall_rpc.AddressResponse") + proto.RegisterType((*RecvmsgResponse)(nil), "syscall_rpc.RecvmsgResponse") + proto.RegisterType((*RecvmsgResponse_ResultPayload)(nil), "syscall_rpc.RecvmsgResponse.ResultPayload") + proto.RegisterType((*BindRequest)(nil), "syscall_rpc.BindRequest") + proto.RegisterType((*BindResponse)(nil), "syscall_rpc.BindResponse") + proto.RegisterType((*AcceptRequest)(nil), "syscall_rpc.AcceptRequest") + proto.RegisterType((*AcceptResponse)(nil), "syscall_rpc.AcceptResponse") + proto.RegisterType((*AcceptResponse_ResultPayload)(nil), "syscall_rpc.AcceptResponse.ResultPayload") + proto.RegisterType((*ConnectRequest)(nil), "syscall_rpc.ConnectRequest") + proto.RegisterType((*ConnectResponse)(nil), "syscall_rpc.ConnectResponse") + proto.RegisterType((*ListenRequest)(nil), "syscall_rpc.ListenRequest") + proto.RegisterType((*ListenResponse)(nil), "syscall_rpc.ListenResponse") + proto.RegisterType((*ShutdownRequest)(nil), "syscall_rpc.ShutdownRequest") + proto.RegisterType((*ShutdownResponse)(nil), "syscall_rpc.ShutdownResponse") + proto.RegisterType((*CloseRequest)(nil), "syscall_rpc.CloseRequest") + proto.RegisterType((*CloseResponse)(nil), "syscall_rpc.CloseResponse") + proto.RegisterType((*GetSockOptRequest)(nil), "syscall_rpc.GetSockOptRequest") + proto.RegisterType((*GetSockOptResponse)(nil), "syscall_rpc.GetSockOptResponse") + proto.RegisterType((*SetSockOptRequest)(nil), "syscall_rpc.SetSockOptRequest") + proto.RegisterType((*SetSockOptResponse)(nil), "syscall_rpc.SetSockOptResponse") + proto.RegisterType((*GetSockNameRequest)(nil), "syscall_rpc.GetSockNameRequest") + proto.RegisterType((*GetSockNameResponse)(nil), "syscall_rpc.GetSockNameResponse") + proto.RegisterType((*GetPeerNameRequest)(nil), "syscall_rpc.GetPeerNameRequest") + proto.RegisterType((*GetPeerNameResponse)(nil), "syscall_rpc.GetPeerNameResponse") + proto.RegisterType((*SocketRequest)(nil), "syscall_rpc.SocketRequest") + proto.RegisterType((*SocketResponse)(nil), "syscall_rpc.SocketResponse") + proto.RegisterType((*EpollWaitRequest)(nil), "syscall_rpc.EpollWaitRequest") + proto.RegisterType((*EpollEvent)(nil), "syscall_rpc.EpollEvent") + proto.RegisterType((*EpollEvents)(nil), "syscall_rpc.EpollEvents") + proto.RegisterType((*EpollWaitResponse)(nil), "syscall_rpc.EpollWaitResponse") + proto.RegisterType((*EpollCtlRequest)(nil), "syscall_rpc.EpollCtlRequest") + proto.RegisterType((*EpollCtlResponse)(nil), "syscall_rpc.EpollCtlResponse") + proto.RegisterType((*EpollCreate1Request)(nil), "syscall_rpc.EpollCreate1Request") + proto.RegisterType((*EpollCreate1Response)(nil), "syscall_rpc.EpollCreate1Response") + proto.RegisterType((*PollRequest)(nil), "syscall_rpc.PollRequest") + proto.RegisterType((*PollResponse)(nil), "syscall_rpc.PollResponse") + proto.RegisterType((*SyscallRequest)(nil), "syscall_rpc.SyscallRequest") + proto.RegisterType((*SyscallResponse)(nil), "syscall_rpc.SyscallResponse") +} + +func init() { + proto.RegisterFile("pkg/sentry/socket/rpcinet/syscall_rpc.proto", fileDescriptor_dd04f3a8f0c5288b) +} + +var fileDescriptor_dd04f3a8f0c5288b = []byte{ + // 1838 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xbc, 0x98, 0xdd, 0x52, 0xe3, 0xc8, + 0x15, 0x80, 0x6d, 0x6c, 0xc0, 0x1c, 0xff, 0xa2, 0x21, 0xac, 0xf8, 0x67, 0x95, 0xa4, 0x8a, 0x4d, + 0x2a, 0xb8, 0xc6, 0x33, 0xec, 0x90, 0xdd, 0xad, 0xdd, 0x2c, 0x04, 0xd6, 0x53, 0x99, 0x1a, 0x88, + 0x5c, 0x09, 0xa9, 0xe4, 0xc2, 0x25, 0xa4, 0xb6, 0x71, 0x21, 0x4b, 0x8a, 0x24, 0x43, 0x71, 0x93, + 0x07, 0xc8, 0x75, 0xee, 0x72, 0x91, 0x87, 0xca, 0x03, 0xe4, 0x09, 0xf2, 0x0e, 0xa9, 0xd3, 0xdd, + 0x92, 0xba, 0x45, 0x8b, 0xc1, 0x29, 0x6a, 0xef, 0xd4, 0xad, 0xf3, 0xd7, 0xe7, 0x74, 0x7f, 0x3a, + 0x2d, 0xf8, 0x65, 0x70, 0x3b, 0xee, 0x46, 0xc4, 0x8b, 0xc3, 0x87, 0x6e, 0xe4, 0xdb, 0xb7, 0x24, + 0xee, 0x86, 0x81, 0x3d, 0xf1, 0x48, 0xdc, 0x8d, 0x1e, 0x22, 0xdb, 0x72, 0xdd, 0x61, 0x18, 0xd8, + 0x87, 0x41, 0xe8, 0xc7, 0xbe, 0x56, 0x17, 0xa6, 0x8c, 0xbf, 0x97, 0xa1, 0x35, 0x20, 0x9e, 0x33, + 0x8d, 0xc6, 0x26, 0xf9, 0xeb, 0x8c, 0x44, 0xb1, 0xd6, 0x82, 0x85, 0x91, 0xa3, 0x97, 0xf7, 0xcb, + 0x07, 0x4d, 0x73, 0x61, 0xe4, 0x68, 0xeb, 0x50, 0x75, 0xac, 0xd8, 0xd2, 0x17, 0xf6, 0xcb, 0x07, + 0x8d, 0x93, 0x85, 0x5a, 0xd9, 0xa4, 0x63, 0x4d, 0x87, 0x65, 0xcb, 0x71, 0x42, 0x12, 0x45, 0x7a, + 0x05, 0x5f, 0x99, 0xc9, 0x50, 0xd3, 0xa0, 0x3a, 0xf5, 0x43, 0xa2, 0x57, 0xf7, 0xcb, 0x07, 0x35, + 0x93, 0x3e, 0x6b, 0x06, 0x34, 0x89, 0xe7, 0x0c, 0xfd, 0xd1, 0x30, 0x24, 0xb6, 0x1f, 0x3a, 0xfa, + 0x22, 0x7d, 0x59, 0x27, 0x9e, 0x73, 0x31, 0x32, 0xe9, 0x94, 0xf1, 0x67, 0x68, 0xa7, 0xb1, 0x44, + 0x81, 0xef, 0x45, 0x44, 0xfb, 0x29, 0x34, 0x48, 0x18, 0xfa, 0xe1, 0xd0, 0x9b, 0x4d, 0xaf, 0x49, + 0xc8, 0xc2, 0xea, 0x97, 0xcc, 0x3a, 0x9d, 0xfd, 0x48, 0x27, 0x35, 0x1d, 0x96, 0x5c, 0xe2, 0x8d, + 0xe3, 0x1b, 0x1a, 0x23, 0xbe, 0xe6, 0xe3, 0x93, 0x1a, 0x2c, 0x85, 0x24, 0x9a, 0xb9, 0xb1, 0x71, + 0x02, 0x8d, 0xf7, 0x17, 0xa7, 0xb1, 0x5b, 0xb4, 0xca, 0x0e, 0x54, 0xec, 0xa9, 0xc3, 0x0c, 0x98, + 0xf8, 0x88, 0x33, 0x56, 0x38, 0xe6, 0x6b, 0xc3, 0x47, 0xe3, 0x8f, 0xd0, 0xe4, 0x36, 0xe6, 0x89, + 0x6e, 0x1d, 0x16, 0xef, 0x2c, 0x77, 0x46, 0x58, 0x02, 0xfb, 0x25, 0x93, 0x0d, 0x85, 0xd8, 0xfe, + 0x59, 0x86, 0x96, 0x49, 0xec, 0xbb, 0x27, 0x8b, 0x20, 0x2d, 0x31, 0x59, 0x20, 0xce, 0x47, 0xc4, + 0x73, 0x48, 0x48, 0xe3, 0xac, 0x99, 0x7c, 0x84, 0x25, 0x08, 0x08, 0xb9, 0x4d, 0x4a, 0x80, 0xcf, + 0xda, 0x1a, 0x2c, 0xc6, 0xe1, 0xcc, 0xb3, 0x79, 0xea, 0xd9, 0x40, 0xdb, 0x83, 0xba, 0x3d, 0x8d, + 0xc6, 0x43, 0x6e, 0x7e, 0x89, 0x9a, 0x07, 0x9c, 0xfa, 0x40, 0x67, 0x8c, 0xdf, 0x41, 0xfd, 0x22, + 0x20, 0x5e, 0x12, 0x19, 0x5a, 0xb6, 0xe2, 0x1b, 0x1a, 0x5b, 0xc3, 0xa4, 0xcf, 0x68, 0x79, 0xe4, + 0x5a, 0xe3, 0x88, 0x07, 0xc7, 0x06, 0x6c, 0x1b, 0x38, 0x84, 0x46, 0xd6, 0x34, 0xe9, 0xb3, 0x71, + 0x01, 0x0d, 0x66, 0x6c, 0x9e, 0x0c, 0x76, 0x68, 0x32, 0x92, 0xda, 0x2e, 0x8c, 0x1c, 0x21, 0x77, + 0x47, 0x50, 0x37, 0x89, 0xe5, 0xcc, 0x99, 0x37, 0xe3, 0x0a, 0x1a, 0x4c, 0x6d, 0xbe, 0x7d, 0x96, + 0x3b, 0x09, 0xfd, 0x12, 0x3b, 0x0b, 0x42, 0x3c, 0x3f, 0x87, 0x36, 0x1a, 0x3e, 0x9f, 0xb8, 0x44, + 0x95, 0xb1, 0x15, 0x96, 0x31, 0xe3, 0x2f, 0xd0, 0xc9, 0xc4, 0x5e, 0x3a, 0x86, 0x2f, 0xa1, 0x71, + 0x15, 0x4e, 0x62, 0x32, 0xe7, 0x89, 0x36, 0xfe, 0x04, 0x4d, 0xae, 0xf7, 0xd2, 0xa7, 0xef, 0x37, + 0xd0, 0xa1, 0x96, 0x3f, 0x91, 0x16, 0x64, 0x8a, 0xed, 0x7b, 0x31, 0xf1, 0x62, 0x16, 0x9c, 0x99, + 0x0c, 0x8d, 0x4b, 0x58, 0x15, 0x2c, 0xf0, 0xf8, 0x3e, 0x57, 0xc5, 0x97, 0x8f, 0x6e, 0xf9, 0x3e, + 0x9c, 0xc4, 0x31, 0xf1, 0xf8, 0x0e, 0x48, 0x86, 0xc6, 0x29, 0xb4, 0xbf, 0x67, 0xc0, 0x4a, 0xed, + 0x09, 0x48, 0x2b, 0xcb, 0x48, 0x2b, 0xda, 0x47, 0xff, 0x5a, 0xc0, 0x7a, 0xf3, 0xa3, 0x3b, 0x4f, + 0xd6, 0xce, 0x61, 0x39, 0xb0, 0x1e, 0x5c, 0xdf, 0x62, 0x1b, 0xbb, 0xde, 0xfb, 0xc5, 0xa1, 0x88, + 0xea, 0x9c, 0xcd, 0x43, 0x93, 0xe6, 0xf1, 0x92, 0x69, 0xf4, 0x4b, 0x66, 0xa2, 0xbc, 0xf9, 0x8f, + 0x32, 0x34, 0xa5, 0x97, 0x69, 0x75, 0xcb, 0x39, 0x5e, 0x7f, 0x99, 0x2d, 0x8e, 0x79, 0xdc, 0x96, + 0x3c, 0xe6, 0x72, 0xa1, 0x5a, 0x7a, 0x45, 0x42, 0xcf, 0x16, 0xac, 0x50, 0x70, 0x50, 0x67, 0x55, + 0x9a, 0xae, 0x1a, 0x4e, 0xfc, 0x56, 0xde, 0x8c, 0xef, 0xa0, 0x7e, 0x32, 0xf1, 0x0a, 0x0f, 0xa8, + 0x2e, 0x47, 0x95, 0xa5, 0xdc, 0x78, 0x0d, 0x0d, 0xa6, 0xf8, 0xec, 0x62, 0x1b, 0xef, 0xa1, 0xf9, + 0xbd, 0x6d, 0x93, 0x20, 0x2e, 0xf2, 0xc6, 0xb0, 0x18, 0x52, 0x57, 0x0c, 0x8b, 0x61, 0x06, 0x2f, + 0x5c, 0x5e, 0x85, 0xc3, 0xcb, 0xf8, 0x4f, 0x19, 0x5a, 0x89, 0xad, 0x79, 0xea, 0x7a, 0x96, 0xaf, + 0xeb, 0x17, 0x72, 0x96, 0x25, 0x93, 0xc5, 0x65, 0xbd, 0xca, 0x57, 0x35, 0xbf, 0x92, 0xff, 0xb3, + 0x9a, 0x42, 0x61, 0xbe, 0x82, 0xd6, 0xa9, 0xef, 0x79, 0xc4, 0x8e, 0xe7, 0xaf, 0xcd, 0x5b, 0x68, + 0xa7, 0xba, 0xcf, 0x2f, 0xcf, 0xaf, 0xa1, 0xf9, 0x61, 0x12, 0xc5, 0xd9, 0xb7, 0x44, 0xe1, 0xf0, + 0xda, 0xb2, 0x6f, 0x5d, 0x7f, 0x4c, 0x1d, 0x56, 0xcc, 0x64, 0x68, 0xbc, 0x81, 0x56, 0xa2, 0xfa, + 0x7c, 0x7f, 0x6f, 0xa0, 0x3d, 0xb8, 0x99, 0xc5, 0x8e, 0x7f, 0xef, 0x3d, 0xf1, 0xd9, 0xbf, 0xf1, + 0xef, 0xb9, 0x37, 0x7c, 0x34, 0x8e, 0xa0, 0x93, 0x29, 0x3d, 0xdf, 0xd7, 0x2e, 0x34, 0x4e, 0x5d, + 0x3f, 0x2a, 0x62, 0xae, 0xd1, 0x83, 0x26, 0x7f, 0xff, 0x7c, 0x9b, 0x04, 0x56, 0x7f, 0x20, 0xf1, + 0xc0, 0xb7, 0x6f, 0x2f, 0x8a, 0xb7, 0xf4, 0x1a, 0x2c, 0xba, 0xe4, 0x8e, 0xb8, 0x7c, 0x0d, 0x6c, + 0x80, 0x1b, 0xdd, 0xb3, 0xa6, 0x84, 0xef, 0x69, 0xfa, 0x2c, 0x1c, 0xe4, 0x6a, 0xee, 0x5b, 0xa8, + 0x89, 0x6e, 0xe6, 0xd9, 0xed, 0x1a, 0x54, 0xfc, 0x20, 0x4e, 0x3b, 0x1b, 0x1c, 0x08, 0x3b, 0x6c, + 0x08, 0xab, 0x83, 0x17, 0x8c, 0xbf, 0xc3, 0x9c, 0x31, 0xd4, 0xe0, 0xa3, 0xf1, 0x0e, 0xb4, 0xc1, + 0xe3, 0xc8, 0x9f, 0x91, 0xd9, 0x9f, 0xa5, 0x4b, 0xfe, 0x68, 0x4d, 0x0b, 0x6b, 0xf6, 0x37, 0x78, + 0x25, 0x49, 0xcd, 0x93, 0x99, 0xe3, 0xb9, 0xce, 0x27, 0x1e, 0xfd, 0xc7, 0x27, 0x94, 0x45, 0x79, + 0x49, 0x48, 0xf8, 0xe9, 0x28, 0x33, 0xa9, 0x1f, 0x3b, 0xca, 0x2b, 0x68, 0x0e, 0xe8, 0x9d, 0x23, + 0x09, 0x70, 0x1d, 0x96, 0x46, 0xd6, 0x74, 0xe2, 0x3e, 0x50, 0x9f, 0x15, 0x93, 0x8f, 0xb0, 0xa6, + 0xf1, 0x43, 0x40, 0x78, 0xa1, 0xe9, 0xb3, 0xb6, 0x09, 0x35, 0x7a, 0x2b, 0xb1, 0x7d, 0x97, 0xd7, + 0x3a, 0x1d, 0x1b, 0xbf, 0x87, 0x56, 0x62, 0xf8, 0xa5, 0xba, 0xc5, 0x3f, 0x40, 0xe7, 0x2c, 0xf0, + 0x5d, 0xf7, 0xca, 0x9a, 0x14, 0x6e, 0xc8, 0x1d, 0x00, 0x6f, 0x36, 0x1d, 0x92, 0x3b, 0xe2, 0xc5, + 0x49, 0x47, 0xbb, 0xe2, 0xcd, 0xa6, 0x67, 0x74, 0x82, 0x76, 0xb5, 0x11, 0xb1, 0x69, 0xb4, 0x9a, + 0x49, 0x9f, 0x8d, 0xb7, 0x00, 0xd4, 0x2c, 0x15, 0x51, 0xf5, 0xa0, 0x92, 0x31, 0x3e, 0x32, 0xbe, + 0x85, 0x7a, 0xa6, 0x15, 0x69, 0xdd, 0x54, 0xac, 0xbc, 0x5f, 0x39, 0xa8, 0xf7, 0x3e, 0x93, 0x4a, + 0x91, 0x49, 0xa6, 0xfa, 0x77, 0xb0, 0x2a, 0x2c, 0x66, 0x9e, 0x14, 0xf5, 0xa4, 0x88, 0xea, 0x3d, + 0xbd, 0xc0, 0x55, 0x84, 0xcd, 0x1c, 0x93, 0x14, 0x92, 0x18, 0x43, 0x9b, 0x8a, 0x08, 0xb7, 0x29, + 0x0d, 0xaa, 0x24, 0x48, 0x17, 0x4d, 0x9f, 0x31, 0x0d, 0x7e, 0xc0, 0x8b, 0xbd, 0xe0, 0x07, 0x3c, + 0x2d, 0x95, 0x34, 0x2d, 0xbf, 0x82, 0x45, 0x6a, 0x9a, 0x1e, 0xe8, 0x27, 0x96, 0xcb, 0xa4, 0x90, + 0xcb, 0x99, 0xd7, 0xe7, 0x9f, 0xf4, 0x2f, 0xe0, 0x15, 0x53, 0x0b, 0x89, 0x15, 0x93, 0xd7, 0x42, + 0xc0, 0xf8, 0x9d, 0xe7, 0x3b, 0x94, 0x3e, 0x1b, 0x57, 0xb0, 0x26, 0x8b, 0xbe, 0xe0, 0x1d, 0xe5, + 0xd2, 0x77, 0xdd, 0x27, 0xee, 0x28, 0xca, 0xfd, 0x71, 0x05, 0x0d, 0xa6, 0x36, 0x67, 0x37, 0x2e, + 0x1a, 0x53, 0x16, 0xf0, 0xdf, 0x00, 0xad, 0x01, 0x4b, 0x76, 0x12, 0xd3, 0x5b, 0x58, 0x62, 0x3f, + 0x0e, 0xa8, 0xd5, 0x7a, 0x6f, 0x53, 0xaa, 0x86, 0x74, 0xbe, 0xd1, 0x24, 0x93, 0xd5, 0xde, 0xc1, + 0x72, 0xc4, 0x2e, 0xec, 0x7c, 0x23, 0x6d, 0xc9, 0x6a, 0xd2, 0x8f, 0x05, 0xa4, 0x07, 0x97, 0x46, + 0xc5, 0x90, 0x75, 0xb8, 0x74, 0x43, 0xe4, 0x15, 0xe5, 0xcb, 0x30, 0x2a, 0x72, 0x69, 0xed, 0x10, + 0xaa, 0xd7, 0x13, 0xcf, 0xe1, 0x7b, 0x46, 0xde, 0xb7, 0x42, 0x9b, 0x89, 0x97, 0x22, 0x94, 0xc3, + 0x75, 0x59, 0xb4, 0xe5, 0xa2, 0x97, 0xde, 0xfc, 0xba, 0xa4, 0x66, 0x11, 0xd7, 0xc5, 0x64, 0x31, + 0x3c, 0x9b, 0xb5, 0x37, 0xf4, 0x3e, 0x9c, 0x0f, 0x4f, 0x6e, 0x9b, 0x30, 0x3c, 0x2e, 0x8d, 0xee, + 0x5c, 0xda, 0xa6, 0xe8, 0xcb, 0x0a, 0x77, 0x52, 0xf3, 0x43, 0xef, 0x49, 0x74, 0x42, 0xfb, 0x0a, + 0x6a, 0x11, 0x6f, 0x39, 0xf4, 0x9a, 0x02, 0xc3, 0xb9, 0x26, 0xa6, 0x5f, 0x32, 0x53, 0x79, 0xed, + 0x35, 0x2c, 0xda, 0xd8, 0x57, 0xe8, 0x2b, 0x54, 0x71, 0x43, 0x0e, 0x54, 0xe8, 0x48, 0xfa, 0x25, + 0x93, 0x49, 0x6a, 0x27, 0xd0, 0x18, 0x93, 0x78, 0x88, 0x35, 0x1c, 0xe2, 0x07, 0x15, 0xa8, 0xe6, + 0xae, 0xa4, 0xf9, 0xa8, 0xef, 0xe8, 0x97, 0x4c, 0x18, 0xa7, 0x93, 0x68, 0x23, 0x12, 0x6d, 0xd4, + 0x15, 0x36, 0x06, 0x2a, 0x1b, 0x51, 0x66, 0xe3, 0x0c, 0x9a, 0x69, 0x1c, 0xf4, 0x63, 0xdf, 0xa0, + 0x46, 0xf6, 0x54, 0x81, 0x08, 0x1f, 0x40, 0xdc, 0xf1, 0xe3, 0x6c, 0x36, 0x31, 0x83, 0xbd, 0x3c, + 0x33, 0xd3, 0x54, 0x9b, 0xc9, 0x7d, 0x47, 0xb9, 0x99, 0x64, 0x56, 0xfb, 0x16, 0x80, 0xe0, 0xe9, + 0x1f, 0xde, 0x5b, 0x93, 0x58, 0x6f, 0x51, 0x1b, 0x3b, 0x8f, 0x99, 0x24, 0x7c, 0x39, 0xfa, 0x25, + 0x73, 0x85, 0x24, 0x73, 0xda, 0xd7, 0xc0, 0x06, 0x43, 0x3b, 0x76, 0xf5, 0xb6, 0xa2, 0x8a, 0x39, + 0x66, 0x62, 0x15, 0x09, 0x9f, 0xd2, 0x7e, 0x80, 0x26, 0x57, 0x66, 0xec, 0xd1, 0x3b, 0xd4, 0xc0, + 0xbe, 0xc2, 0x80, 0xc4, 0xb1, 0x7e, 0xc9, 0x6c, 0x10, 0x61, 0x1a, 0xcf, 0x07, 0x0e, 0xf5, 0x55, + 0xc5, 0xf9, 0x10, 0x18, 0x84, 0xe7, 0x03, 0xe5, 0x50, 0x3e, 0x24, 0x96, 0xa3, 0x6b, 0x0a, 0x79, + 0xe1, 0xbf, 0x0a, 0xca, 0xa3, 0x1c, 0x6e, 0x37, 0xbc, 0x3f, 0x13, 0xfd, 0x95, 0x62, 0xbb, 0x89, + 0x3f, 0x1d, 0x70, 0xbb, 0x51, 0x49, 0x74, 0xe1, 0x07, 0xc4, 0xd3, 0xd7, 0x14, 0x2e, 0x84, 0x1f, + 0x4b, 0xe8, 0x02, 0xe5, 0xd0, 0xc5, 0xc4, 0xc7, 0x24, 0xfe, 0x44, 0xe1, 0x42, 0xfc, 0x87, 0x87, + 0x2e, 0xa8, 0x24, 0xd6, 0x8e, 0xfa, 0x1a, 0x8e, 0x26, 0x2e, 0xd1, 0xd7, 0x15, 0xb5, 0xcb, 0xff, + 0x7d, 0xc0, 0xda, 0xdd, 0x27, 0x73, 0x58, 0x3b, 0x5c, 0x1d, 0x53, 0xff, 0x4c, 0x51, 0xbb, 0xdc, + 0x2f, 0x1d, 0xac, 0x5d, 0xc8, 0xa7, 0x4e, 0x96, 0xa0, 0x6a, 0x85, 0xe3, 0xc8, 0xf8, 0x2f, 0x40, + 0x3b, 0xa5, 0x2a, 0x47, 0xf6, 0x51, 0x0e, 0xab, 0x5b, 0x4a, 0xac, 0xa6, 0xdd, 0x55, 0xc2, 0xd5, + 0xe3, 0x3c, 0x57, 0xb7, 0xd5, 0x5c, 0xcd, 0xda, 0xb2, 0x04, 0xac, 0xc7, 0x79, 0xb0, 0x6e, 0x3f, + 0xf5, 0x5b, 0x41, 0x24, 0x6b, 0x57, 0x22, 0xeb, 0x86, 0x82, 0xac, 0xa9, 0x0e, 0x43, 0xeb, 0x51, + 0x0e, 0xad, 0x5b, 0x4f, 0x5c, 0x74, 0x05, 0xb6, 0x1e, 0xe7, 0xd9, 0xba, 0xad, 0x66, 0x6b, 0x16, + 0x61, 0x02, 0xd7, 0xa3, 0x1c, 0x5c, 0xb7, 0x94, 0x70, 0xcd, 0x1c, 0x72, 0xba, 0x7e, 0xfd, 0x88, + 0xae, 0x3b, 0x05, 0x74, 0x4d, 0x55, 0x33, 0xbc, 0xf6, 0x64, 0xbc, 0x6e, 0xaa, 0xf0, 0x9a, 0xaa, + 0x71, 0xbe, 0x9e, 0x2a, 0xf9, 0xba, 0x57, 0xc8, 0xd7, 0x54, 0x5f, 0x04, 0xec, 0xa9, 0x12, 0xb0, + 0x7b, 0x85, 0x80, 0xcd, 0x8c, 0x08, 0x84, 0x3d, 0x57, 0x13, 0x76, 0xbf, 0x98, 0xb0, 0xa9, 0x19, + 0x09, 0xb1, 0xe7, 0x6a, 0xc4, 0xee, 0x17, 0x23, 0x56, 0xb2, 0x93, 0x32, 0xf6, 0x3b, 0x05, 0x63, + 0x77, 0x8b, 0x18, 0x9b, 0x9a, 0x10, 0x20, 0xfb, 0xcd, 0x63, 0xc8, 0xee, 0x14, 0x40, 0x36, 0x2b, + 0x66, 0x4a, 0xd9, 0xbe, 0x9a, 0xb2, 0x9f, 0x3f, 0x41, 0xd9, 0xd4, 0x8a, 0x8c, 0xd9, 0xae, 0x84, + 0xd9, 0x0d, 0x05, 0x66, 0xb3, 0xc3, 0x42, 0x39, 0xdb, 0x95, 0x38, 0xbb, 0xa1, 0xe0, 0x6c, 0xa6, + 0x40, 0x41, 0xdb, 0x93, 0x41, 0xbb, 0xa9, 0x02, 0x6d, 0xb6, 0xf1, 0x18, 0x69, 0xbb, 0x12, 0x69, + 0x37, 0x14, 0xa4, 0xcd, 0x9c, 0x50, 0xd4, 0xf6, 0x64, 0xd4, 0x6e, 0xaa, 0x50, 0x9b, 0x39, 0x61, + 0xac, 0xfd, 0x4e, 0xc1, 0xda, 0xdd, 0x22, 0xd6, 0x66, 0x35, 0xcc, 0x60, 0xfb, 0xcd, 0x63, 0xd8, + 0xee, 0x14, 0xc0, 0x36, 0xab, 0x61, 0x4a, 0xdb, 0xb4, 0x8b, 0xbd, 0x5e, 0xa2, 0x17, 0xc5, 0x37, + 0xff, 0x0b, 0x00, 0x00, 0xff, 0xff, 0xd5, 0x22, 0x29, 0x56, 0xfd, 0x1a, 0x00, 0x00, +} diff --git a/pkg/sentry/socket/socket_state_autogen.go b/pkg/sentry/socket/socket_state_autogen.go new file mode 100755 index 000000000..3df219247 --- /dev/null +++ b/pkg/sentry/socket/socket_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package socket + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SendReceiveTimeout) beforeSave() {} +func (x *SendReceiveTimeout) save(m state.Map) { + x.beforeSave() + m.Save("send", &x.send) + m.Save("recv", &x.recv) +} + +func (x *SendReceiveTimeout) afterLoad() {} +func (x *SendReceiveTimeout) load(m state.Map) { + m.Load("send", &x.send) + m.Load("recv", &x.recv) +} + +func init() { + state.Register("socket.SendReceiveTimeout", (*SendReceiveTimeout)(nil), state.Fns{Save: (*SendReceiveTimeout).save, Load: (*SendReceiveTimeout).load}) +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD deleted file mode 100644 index 8580eb87d..000000000 --- a/pkg/sentry/socket/unix/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "unix", - srcs = [ - "device.go", - "io.go", - "unix.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/refs", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD deleted file mode 100644 index 82173dea7..000000000 --- a/pkg/sentry/socket/unix/transport/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -go_template_instance( - name = "transport_message_list", - out = "transport_message_list.go", - package = "transport", - prefix = "message", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*message", - "Linker": "*message", - }, -) - -go_library( - name = "transport", - srcs = [ - "connectioned.go", - "connectioned_state.go", - "connectionless.go", - "queue.go", - "transport_message_list.go", - "unix.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport", - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/ilist", - "//pkg/refs", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/unix/transport/transport_message_list.go b/pkg/sentry/socket/unix/transport/transport_message_list.go new file mode 100755 index 000000000..6d394860c --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_message_list.go @@ -0,0 +1,173 @@ +package transport + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type messageElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (messageElementMapper) linkerFor(elem *message) *message { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type messageList struct { + head *message + tail *message +} + +// Reset resets list l to the empty state. +func (l *messageList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *messageList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *messageList) Front() *message { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *messageList) Back() *message { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *messageList) PushFront(e *message) { + messageElementMapper{}.linkerFor(e).SetNext(l.head) + messageElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + messageElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *messageList) PushBack(e *message) { + messageElementMapper{}.linkerFor(e).SetNext(nil) + messageElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *messageList) PushBackList(m *messageList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(m.head) + messageElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *messageList) InsertAfter(b, e *message) { + a := messageElementMapper{}.linkerFor(b).Next() + messageElementMapper{}.linkerFor(e).SetNext(a) + messageElementMapper{}.linkerFor(e).SetPrev(b) + messageElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + messageElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *messageList) InsertBefore(a, e *message) { + b := messageElementMapper{}.linkerFor(a).Prev() + messageElementMapper{}.linkerFor(e).SetNext(a) + messageElementMapper{}.linkerFor(e).SetPrev(b) + messageElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + messageElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *messageList) Remove(e *message) { + prev := messageElementMapper{}.linkerFor(e).Prev() + next := messageElementMapper{}.linkerFor(e).Next() + + if prev != nil { + messageElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + messageElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type messageEntry struct { + next *message + prev *message +} + +// Next returns the entry that follows e in the list. +func (e *messageEntry) Next() *message { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *messageEntry) Prev() *message { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *messageEntry) SetNext(elem *message) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *messageEntry) SetPrev(elem *message) { + e.prev = elem +} diff --git a/pkg/sentry/socket/unix/transport/transport_state_autogen.go b/pkg/sentry/socket/unix/transport/transport_state_autogen.go new file mode 100755 index 000000000..83eedb711 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_state_autogen.go @@ -0,0 +1,191 @@ +// automatically generated by stateify. + +package transport + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *connectionedEndpoint) beforeSave() {} +func (x *connectionedEndpoint) save(m state.Map) { + x.beforeSave() + var acceptedChan []*connectionedEndpoint = x.saveAcceptedChan() + m.SaveValue("acceptedChan", acceptedChan) + m.Save("baseEndpoint", &x.baseEndpoint) + m.Save("id", &x.id) + m.Save("idGenerator", &x.idGenerator) + m.Save("stype", &x.stype) +} + +func (x *connectionedEndpoint) afterLoad() {} +func (x *connectionedEndpoint) load(m state.Map) { + m.Load("baseEndpoint", &x.baseEndpoint) + m.Load("id", &x.id) + m.Load("idGenerator", &x.idGenerator) + m.Load("stype", &x.stype) + m.LoadValue("acceptedChan", new([]*connectionedEndpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*connectionedEndpoint)) }) +} + +func (x *connectionlessEndpoint) beforeSave() {} +func (x *connectionlessEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("baseEndpoint", &x.baseEndpoint) +} + +func (x *connectionlessEndpoint) afterLoad() {} +func (x *connectionlessEndpoint) load(m state.Map) { + m.Load("baseEndpoint", &x.baseEndpoint) +} + +func (x *queue) beforeSave() {} +func (x *queue) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("ReaderQueue", &x.ReaderQueue) + m.Save("WriterQueue", &x.WriterQueue) + m.Save("closed", &x.closed) + m.Save("used", &x.used) + m.Save("limit", &x.limit) + m.Save("dataList", &x.dataList) +} + +func (x *queue) afterLoad() {} +func (x *queue) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("ReaderQueue", &x.ReaderQueue) + m.Load("WriterQueue", &x.WriterQueue) + m.Load("closed", &x.closed) + m.Load("used", &x.used) + m.Load("limit", &x.limit) + m.Load("dataList", &x.dataList) +} + +func (x *messageList) beforeSave() {} +func (x *messageList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *messageList) afterLoad() {} +func (x *messageList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *messageEntry) beforeSave() {} +func (x *messageEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *messageEntry) afterLoad() {} +func (x *messageEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func (x *ControlMessages) beforeSave() {} +func (x *ControlMessages) save(m state.Map) { + x.beforeSave() + m.Save("Rights", &x.Rights) + m.Save("Credentials", &x.Credentials) +} + +func (x *ControlMessages) afterLoad() {} +func (x *ControlMessages) load(m state.Map) { + m.Load("Rights", &x.Rights) + m.Load("Credentials", &x.Credentials) +} + +func (x *message) beforeSave() {} +func (x *message) save(m state.Map) { + x.beforeSave() + m.Save("messageEntry", &x.messageEntry) + m.Save("Data", &x.Data) + m.Save("Control", &x.Control) + m.Save("Address", &x.Address) +} + +func (x *message) afterLoad() {} +func (x *message) load(m state.Map) { + m.Load("messageEntry", &x.messageEntry) + m.Load("Data", &x.Data) + m.Load("Control", &x.Control) + m.Load("Address", &x.Address) +} + +func (x *queueReceiver) beforeSave() {} +func (x *queueReceiver) save(m state.Map) { + x.beforeSave() + m.Save("readQueue", &x.readQueue) +} + +func (x *queueReceiver) afterLoad() {} +func (x *queueReceiver) load(m state.Map) { + m.Load("readQueue", &x.readQueue) +} + +func (x *streamQueueReceiver) beforeSave() {} +func (x *streamQueueReceiver) save(m state.Map) { + x.beforeSave() + m.Save("queueReceiver", &x.queueReceiver) + m.Save("buffer", &x.buffer) + m.Save("control", &x.control) + m.Save("addr", &x.addr) +} + +func (x *streamQueueReceiver) afterLoad() {} +func (x *streamQueueReceiver) load(m state.Map) { + m.Load("queueReceiver", &x.queueReceiver) + m.Load("buffer", &x.buffer) + m.Load("control", &x.control) + m.Load("addr", &x.addr) +} + +func (x *connectedEndpoint) beforeSave() {} +func (x *connectedEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("endpoint", &x.endpoint) + m.Save("writeQueue", &x.writeQueue) +} + +func (x *connectedEndpoint) afterLoad() {} +func (x *connectedEndpoint) load(m state.Map) { + m.Load("endpoint", &x.endpoint) + m.Load("writeQueue", &x.writeQueue) +} + +func (x *baseEndpoint) beforeSave() {} +func (x *baseEndpoint) save(m state.Map) { + x.beforeSave() + m.Save("Queue", &x.Queue) + m.Save("passcred", &x.passcred) + m.Save("receiver", &x.receiver) + m.Save("connected", &x.connected) + m.Save("path", &x.path) +} + +func (x *baseEndpoint) afterLoad() {} +func (x *baseEndpoint) load(m state.Map) { + m.Load("Queue", &x.Queue) + m.Load("passcred", &x.passcred) + m.Load("receiver", &x.receiver) + m.Load("connected", &x.connected) + m.Load("path", &x.path) +} + +func init() { + state.Register("transport.connectionedEndpoint", (*connectionedEndpoint)(nil), state.Fns{Save: (*connectionedEndpoint).save, Load: (*connectionedEndpoint).load}) + state.Register("transport.connectionlessEndpoint", (*connectionlessEndpoint)(nil), state.Fns{Save: (*connectionlessEndpoint).save, Load: (*connectionlessEndpoint).load}) + state.Register("transport.queue", (*queue)(nil), state.Fns{Save: (*queue).save, Load: (*queue).load}) + state.Register("transport.messageList", (*messageList)(nil), state.Fns{Save: (*messageList).save, Load: (*messageList).load}) + state.Register("transport.messageEntry", (*messageEntry)(nil), state.Fns{Save: (*messageEntry).save, Load: (*messageEntry).load}) + state.Register("transport.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load}) + state.Register("transport.message", (*message)(nil), state.Fns{Save: (*message).save, Load: (*message).load}) + state.Register("transport.queueReceiver", (*queueReceiver)(nil), state.Fns{Save: (*queueReceiver).save, Load: (*queueReceiver).load}) + state.Register("transport.streamQueueReceiver", (*streamQueueReceiver)(nil), state.Fns{Save: (*streamQueueReceiver).save, Load: (*streamQueueReceiver).load}) + state.Register("transport.connectedEndpoint", (*connectedEndpoint)(nil), state.Fns{Save: (*connectedEndpoint).save, Load: (*connectedEndpoint).load}) + state.Register("transport.baseEndpoint", (*baseEndpoint)(nil), state.Fns{Save: (*baseEndpoint).save, Load: (*baseEndpoint).load}) +} diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go new file mode 100755 index 000000000..86f734156 --- /dev/null +++ b/pkg/sentry/socket/unix/unix_state_autogen.go @@ -0,0 +1,28 @@ +// automatically generated by stateify. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SocketOperations) beforeSave() {} +func (x *SocketOperations) save(m state.Map) { + x.beforeSave() + m.Save("AtomicRefCount", &x.AtomicRefCount) + m.Save("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Save("ep", &x.ep) + m.Save("stype", &x.stype) +} + +func (x *SocketOperations) afterLoad() {} +func (x *SocketOperations) load(m state.Map) { + m.Load("AtomicRefCount", &x.AtomicRefCount) + m.Load("SendReceiveTimeout", &x.SendReceiveTimeout) + m.Load("ep", &x.ep) + m.Load("stype", &x.stype) +} + +func init() { + state.Register("unix.SocketOperations", (*SocketOperations)(nil), state.Fns{Save: (*SocketOperations).save, Load: (*SocketOperations).load}) +} diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD deleted file mode 100644 index f297ef3b7..000000000 --- a/pkg/sentry/state/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "state", - srcs = [ - "state.go", - "state_metadata.go", - "state_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/state", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/watchdog", - "//pkg/state/statefile", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/state/state_state_autogen.go b/pkg/sentry/state/state_state_autogen.go new file mode 100755 index 000000000..6c0d9b7a7 --- /dev/null +++ b/pkg/sentry/state/state_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package state + diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD deleted file mode 100644 index d77c7a433..000000000 --- a/pkg/sentry/strace/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "strace", - srcs = [ - "capability.go", - "clone.go", - "futex.go", - "linux64.go", - "open.go", - "poll.go", - "ptrace.go", - "signal.go", - "socket.go", - "strace.go", - "syscalls.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/strace", - visibility = ["//:sandbox"], - deps = [ - ":strace_go_proto", - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/bits", - "//pkg/eventchannel", - "//pkg/seccomp", - "//pkg/sentry/arch", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", - "//pkg/sentry/socket/netlink", - "//pkg/sentry/syscalls/linux", - "//pkg/sentry/usermem", - ], -) - -proto_library( - name = "strace_proto", - srcs = ["strace.proto"], - visibility = ["//visibility:public"], -) - -go_proto_library( - name = "strace_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", - proto = ":strace_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto deleted file mode 100644 index 4b2f73a5f..000000000 --- a/pkg/sentry/strace/strace.proto +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -message Strace { - // Process name that made the syscall. - string process = 1; - - // Syscall function name. - string function = 2; - - // List of syscall arguments formatted as strings. - repeated string args = 3; - - oneof info { - StraceEnter enter = 4; - StraceExit exit = 5; - } -} - -message StraceEnter { -} - -message StraceExit { - // Return value formatted as string. - string return = 1; - - // Formatted error string in case syscall failed. - string error = 2; - - // Value of errno upon syscall exit. - int64 err_no = 3; // errno is a macro and gets expanded :-( - - // Time elapsed between syscall enter and exit. - int64 elapsed_ns = 4; -} diff --git a/pkg/sentry/strace/strace_go_proto/strace.pb.go b/pkg/sentry/strace/strace_go_proto/strace.pb.go new file mode 100755 index 000000000..ef45661bc --- /dev/null +++ b/pkg/sentry/strace/strace_go_proto/strace.pb.go @@ -0,0 +1,247 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/strace/strace.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type Strace struct { + Process string `protobuf:"bytes,1,opt,name=process,proto3" json:"process,omitempty"` + Function string `protobuf:"bytes,2,opt,name=function,proto3" json:"function,omitempty"` + Args []string `protobuf:"bytes,3,rep,name=args,proto3" json:"args,omitempty"` + // Types that are valid to be assigned to Info: + // *Strace_Enter + // *Strace_Exit + Info isStrace_Info `protobuf_oneof:"info"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Strace) Reset() { *m = Strace{} } +func (m *Strace) String() string { return proto.CompactTextString(m) } +func (*Strace) ProtoMessage() {} +func (*Strace) Descriptor() ([]byte, []int) { + return fileDescriptor_50c4b43677c82b5f, []int{0} +} + +func (m *Strace) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Strace.Unmarshal(m, b) +} +func (m *Strace) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Strace.Marshal(b, m, deterministic) +} +func (m *Strace) XXX_Merge(src proto.Message) { + xxx_messageInfo_Strace.Merge(m, src) +} +func (m *Strace) XXX_Size() int { + return xxx_messageInfo_Strace.Size(m) +} +func (m *Strace) XXX_DiscardUnknown() { + xxx_messageInfo_Strace.DiscardUnknown(m) +} + +var xxx_messageInfo_Strace proto.InternalMessageInfo + +func (m *Strace) GetProcess() string { + if m != nil { + return m.Process + } + return "" +} + +func (m *Strace) GetFunction() string { + if m != nil { + return m.Function + } + return "" +} + +func (m *Strace) GetArgs() []string { + if m != nil { + return m.Args + } + return nil +} + +type isStrace_Info interface { + isStrace_Info() +} + +type Strace_Enter struct { + Enter *StraceEnter `protobuf:"bytes,4,opt,name=enter,proto3,oneof"` +} + +type Strace_Exit struct { + Exit *StraceExit `protobuf:"bytes,5,opt,name=exit,proto3,oneof"` +} + +func (*Strace_Enter) isStrace_Info() {} + +func (*Strace_Exit) isStrace_Info() {} + +func (m *Strace) GetInfo() isStrace_Info { + if m != nil { + return m.Info + } + return nil +} + +func (m *Strace) GetEnter() *StraceEnter { + if x, ok := m.GetInfo().(*Strace_Enter); ok { + return x.Enter + } + return nil +} + +func (m *Strace) GetExit() *StraceExit { + if x, ok := m.GetInfo().(*Strace_Exit); ok { + return x.Exit + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*Strace) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*Strace_Enter)(nil), + (*Strace_Exit)(nil), + } +} + +type StraceEnter struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *StraceEnter) Reset() { *m = StraceEnter{} } +func (m *StraceEnter) String() string { return proto.CompactTextString(m) } +func (*StraceEnter) ProtoMessage() {} +func (*StraceEnter) Descriptor() ([]byte, []int) { + return fileDescriptor_50c4b43677c82b5f, []int{1} +} + +func (m *StraceEnter) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_StraceEnter.Unmarshal(m, b) +} +func (m *StraceEnter) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_StraceEnter.Marshal(b, m, deterministic) +} +func (m *StraceEnter) XXX_Merge(src proto.Message) { + xxx_messageInfo_StraceEnter.Merge(m, src) +} +func (m *StraceEnter) XXX_Size() int { + return xxx_messageInfo_StraceEnter.Size(m) +} +func (m *StraceEnter) XXX_DiscardUnknown() { + xxx_messageInfo_StraceEnter.DiscardUnknown(m) +} + +var xxx_messageInfo_StraceEnter proto.InternalMessageInfo + +type StraceExit struct { + Return string `protobuf:"bytes,1,opt,name=return,proto3" json:"return,omitempty"` + Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"` + ErrNo int64 `protobuf:"varint,3,opt,name=err_no,json=errNo,proto3" json:"err_no,omitempty"` + ElapsedNs int64 `protobuf:"varint,4,opt,name=elapsed_ns,json=elapsedNs,proto3" json:"elapsed_ns,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *StraceExit) Reset() { *m = StraceExit{} } +func (m *StraceExit) String() string { return proto.CompactTextString(m) } +func (*StraceExit) ProtoMessage() {} +func (*StraceExit) Descriptor() ([]byte, []int) { + return fileDescriptor_50c4b43677c82b5f, []int{2} +} + +func (m *StraceExit) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_StraceExit.Unmarshal(m, b) +} +func (m *StraceExit) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_StraceExit.Marshal(b, m, deterministic) +} +func (m *StraceExit) XXX_Merge(src proto.Message) { + xxx_messageInfo_StraceExit.Merge(m, src) +} +func (m *StraceExit) XXX_Size() int { + return xxx_messageInfo_StraceExit.Size(m) +} +func (m *StraceExit) XXX_DiscardUnknown() { + xxx_messageInfo_StraceExit.DiscardUnknown(m) +} + +var xxx_messageInfo_StraceExit proto.InternalMessageInfo + +func (m *StraceExit) GetReturn() string { + if m != nil { + return m.Return + } + return "" +} + +func (m *StraceExit) GetError() string { + if m != nil { + return m.Error + } + return "" +} + +func (m *StraceExit) GetErrNo() int64 { + if m != nil { + return m.ErrNo + } + return 0 +} + +func (m *StraceExit) GetElapsedNs() int64 { + if m != nil { + return m.ElapsedNs + } + return 0 +} + +func init() { + proto.RegisterType((*Strace)(nil), "gvisor.Strace") + proto.RegisterType((*StraceEnter)(nil), "gvisor.StraceEnter") + proto.RegisterType((*StraceExit)(nil), "gvisor.StraceExit") +} + +func init() { proto.RegisterFile("pkg/sentry/strace/strace.proto", fileDescriptor_50c4b43677c82b5f) } + +var fileDescriptor_50c4b43677c82b5f = []byte{ + // 255 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x90, 0xdd, 0x4a, 0xf4, 0x30, + 0x10, 0x86, 0xb7, 0x5f, 0xdb, 0x7c, 0x76, 0x16, 0x4f, 0xc6, 0x1f, 0x82, 0xa0, 0x94, 0x1e, 0x05, + 0x84, 0x2e, 0xe8, 0x1d, 0x08, 0xc2, 0x1e, 0xed, 0x41, 0xbc, 0x80, 0xa5, 0xd6, 0xd9, 0x12, 0x94, + 0x24, 0x4c, 0xb2, 0xb2, 0x5e, 0x96, 0x77, 0x28, 0xa6, 0xf1, 0x07, 0x8f, 0x92, 0x67, 0xde, 0x87, + 0x0c, 0x6f, 0xe0, 0xca, 0x3f, 0x4f, 0xab, 0x40, 0x36, 0xf2, 0xdb, 0x2a, 0x44, 0x1e, 0x46, 0xca, + 0x47, 0xef, 0xd9, 0x45, 0x87, 0x62, 0x7a, 0x35, 0xc1, 0x71, 0xf7, 0x5e, 0x80, 0x78, 0x48, 0x01, + 0x4a, 0xf8, 0xef, 0xd9, 0x8d, 0x14, 0x82, 0x2c, 0xda, 0x42, 0x35, 0xfa, 0x0b, 0xf1, 0x02, 0x8e, + 0x76, 0x7b, 0x3b, 0x46, 0xe3, 0xac, 0xfc, 0x97, 0xa2, 0x6f, 0x46, 0x84, 0x6a, 0xe0, 0x29, 0xc8, + 0xb2, 0x2d, 0x55, 0xa3, 0xd3, 0x1d, 0xaf, 0xa1, 0x26, 0x1b, 0x89, 0x65, 0xd5, 0x16, 0x6a, 0x79, + 0x73, 0xd2, 0xcf, 0xcb, 0xfa, 0x79, 0xd1, 0xfd, 0x67, 0xb4, 0x5e, 0xe8, 0xd9, 0x41, 0x05, 0x15, + 0x1d, 0x4c, 0x94, 0x75, 0x72, 0xf1, 0x8f, 0x7b, 0x30, 0x71, 0xbd, 0xd0, 0xc9, 0xb8, 0x13, 0x50, + 0x19, 0xbb, 0x73, 0xdd, 0x31, 0x2c, 0x7f, 0xbd, 0xd4, 0x79, 0x80, 0x1f, 0x19, 0xcf, 0x41, 0x30, + 0xc5, 0x3d, 0xdb, 0x5c, 0x22, 0x13, 0x9e, 0x42, 0x4d, 0xcc, 0x8e, 0x73, 0x81, 0x19, 0xf0, 0x0c, + 0x04, 0x31, 0x6f, 0xad, 0x93, 0x65, 0x5b, 0xa8, 0x32, 0x8d, 0x37, 0x0e, 0x2f, 0x01, 0xe8, 0x65, + 0xf0, 0x81, 0x9e, 0xb6, 0x36, 0xa4, 0x16, 0xa5, 0x6e, 0xf2, 0x64, 0x13, 0x1e, 0x45, 0xfa, 0xc3, + 0xdb, 0x8f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x42, 0x9a, 0xbc, 0x81, 0x65, 0x01, 0x00, 0x00, +} diff --git a/pkg/sentry/strace/strace_state_autogen.go b/pkg/sentry/strace/strace_state_autogen.go new file mode 100755 index 000000000..9dc697ed6 --- /dev/null +++ b/pkg/sentry/strace/strace_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package strace + diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD deleted file mode 100644 index 18fddee76..000000000 --- a/pkg/sentry/syscalls/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "syscalls", - srcs = [ - "epoll.go", - "syscalls.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls", - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/sentry/arch", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/time", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD deleted file mode 100644 index 428c67799..000000000 --- a/pkg/sentry/syscalls/linux/BUILD +++ /dev/null @@ -1,92 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "linux", - srcs = [ - "error.go", - "flags.go", - "linux64.go", - "sigset.go", - "sys_aio.go", - "sys_capability.go", - "sys_epoll.go", - "sys_eventfd.go", - "sys_file.go", - "sys_futex.go", - "sys_getdents.go", - "sys_identity.go", - "sys_inotify.go", - "sys_lseek.go", - "sys_mempolicy.go", - "sys_mmap.go", - "sys_mount.go", - "sys_pipe.go", - "sys_poll.go", - "sys_prctl.go", - "sys_random.go", - "sys_read.go", - "sys_rlimit.go", - "sys_rusage.go", - "sys_sched.go", - "sys_seccomp.go", - "sys_sem.go", - "sys_shm.go", - "sys_signal.go", - "sys_socket.go", - "sys_splice.go", - "sys_stat.go", - "sys_sync.go", - "sys_sysinfo.go", - "sys_syslog.go", - "sys_thread.go", - "sys_time.go", - "sys_timer.go", - "sys_timerfd.go", - "sys_tls.go", - "sys_utsname.go", - "sys_write.go", - "timespec.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux", - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/bpf", - "//pkg/log", - "//pkg/metric", - "//pkg/rand", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fs/timerfd", - "//pkg/sentry/fs/tmpfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/epoll", - "//pkg/sentry/kernel/eventfd", - "//pkg/sentry/kernel/fasync", - "//pkg/sentry/kernel/kdefs", - "//pkg/sentry/kernel/pipe", - "//pkg/sentry/kernel/sched", - "//pkg/sentry/kernel/shm", - "//pkg/sentry/kernel/time", - "//pkg/sentry/limits", - "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/safemem", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/syscalls", - "//pkg/sentry/usage", - "//pkg/sentry/usermem", - "//pkg/syserror", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/syscalls/linux/linux_state_autogen.go b/pkg/sentry/syscalls/linux/linux_state_autogen.go new file mode 100755 index 000000000..06d1fb23f --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux_state_autogen.go @@ -0,0 +1,80 @@ +// automatically generated by stateify. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *ioEvent) beforeSave() {} +func (x *ioEvent) save(m state.Map) { + x.beforeSave() + m.Save("Data", &x.Data) + m.Save("Obj", &x.Obj) + m.Save("Result", &x.Result) + m.Save("Result2", &x.Result2) +} + +func (x *ioEvent) afterLoad() {} +func (x *ioEvent) load(m state.Map) { + m.Load("Data", &x.Data) + m.Load("Obj", &x.Obj) + m.Load("Result", &x.Result) + m.Load("Result2", &x.Result2) +} + +func (x *futexWaitRestartBlock) beforeSave() {} +func (x *futexWaitRestartBlock) save(m state.Map) { + x.beforeSave() + m.Save("duration", &x.duration) + m.Save("addr", &x.addr) + m.Save("private", &x.private) + m.Save("val", &x.val) + m.Save("mask", &x.mask) +} + +func (x *futexWaitRestartBlock) afterLoad() {} +func (x *futexWaitRestartBlock) load(m state.Map) { + m.Load("duration", &x.duration) + m.Load("addr", &x.addr) + m.Load("private", &x.private) + m.Load("val", &x.val) + m.Load("mask", &x.mask) +} + +func (x *pollRestartBlock) beforeSave() {} +func (x *pollRestartBlock) save(m state.Map) { + x.beforeSave() + m.Save("pfdAddr", &x.pfdAddr) + m.Save("nfds", &x.nfds) + m.Save("timeout", &x.timeout) +} + +func (x *pollRestartBlock) afterLoad() {} +func (x *pollRestartBlock) load(m state.Map) { + m.Load("pfdAddr", &x.pfdAddr) + m.Load("nfds", &x.nfds) + m.Load("timeout", &x.timeout) +} + +func (x *clockNanosleepRestartBlock) beforeSave() {} +func (x *clockNanosleepRestartBlock) save(m state.Map) { + x.beforeSave() + m.Save("c", &x.c) + m.Save("duration", &x.duration) + m.Save("rem", &x.rem) +} + +func (x *clockNanosleepRestartBlock) afterLoad() {} +func (x *clockNanosleepRestartBlock) load(m state.Map) { + m.Load("c", &x.c) + m.Load("duration", &x.duration) + m.Load("rem", &x.rem) +} + +func init() { + state.Register("linux.ioEvent", (*ioEvent)(nil), state.Fns{Save: (*ioEvent).save, Load: (*ioEvent).load}) + state.Register("linux.futexWaitRestartBlock", (*futexWaitRestartBlock)(nil), state.Fns{Save: (*futexWaitRestartBlock).save, Load: (*futexWaitRestartBlock).load}) + state.Register("linux.pollRestartBlock", (*pollRestartBlock)(nil), state.Fns{Save: (*pollRestartBlock).save, Load: (*pollRestartBlock).load}) + state.Register("linux.clockNanosleepRestartBlock", (*clockNanosleepRestartBlock)(nil), state.Fns{Save: (*clockNanosleepRestartBlock).save, Load: (*clockNanosleepRestartBlock).load}) +} diff --git a/pkg/sentry/syscalls/syscalls_state_autogen.go b/pkg/sentry/syscalls/syscalls_state_autogen.go new file mode 100755 index 000000000..c114e7989 --- /dev/null +++ b/pkg/sentry/syscalls/syscalls_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package syscalls + diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD deleted file mode 100644 index d2ede0353..000000000 --- a/pkg/sentry/time/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") - -go_template_instance( - name = "seqatomic_parameters", - out = "seqatomic_parameters.go", - package = "time", - suffix = "Parameters", - template = "//third_party/gvsync:generic_seqatomic", - types = { - "Value": "Parameters", - }, -) - -go_library( - name = "time", - srcs = [ - "arith_arm64.go", - "calibrated_clock.go", - "clock_id.go", - "clocks.go", - "muldiv_amd64.s", - "muldiv_arm64.s", - "parameters.go", - "sampler.go", - "sampler_unsafe.go", - "seqatomic_parameters.go", - "tsc_amd64.s", - "tsc_arm64.s", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/time", - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/metric", - "//pkg/syserror", - "//third_party/gvsync", - ], -) - -go_test( - name = "time_test", - srcs = [ - "calibrated_clock_test.go", - "parameters_test.go", - "sampler_test.go", - ], - embed = [":time"], -) diff --git a/pkg/sentry/time/LICENSE b/pkg/sentry/time/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/sentry/time/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sentry/time/calibrated_clock_test.go b/pkg/sentry/time/calibrated_clock_test.go deleted file mode 100644 index d6622bfe2..000000000 --- a/pkg/sentry/time/calibrated_clock_test.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "testing" - "time" -) - -// newTestCalibratedClock returns a CalibratedClock that collects samples from -// the given sample list and cycle counts from the given cycle list. -func newTestCalibratedClock(samples []sample, cycles []TSCValue) *CalibratedClock { - return &CalibratedClock{ - ref: newTestSampler(samples, cycles), - } -} - -func TestConstantFrequency(t *testing.T) { - // Perfectly constant frequency. - samples := []sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200}, - {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300}, - {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400}, - {before: 500000, after: 500000 + defaultOverheadCycles, ref: 500}, - {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600}, - {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700}, - } - - c := newTestCalibratedClock(samples, nil) - - // Update from all samples. - for range samples { - c.Update() - } - - c.mu.RLock() - if !c.ready { - c.mu.RUnlock() - t.Fatalf("clock not ready") - } - // A bit after the last sample. - now, ok := c.params.ComputeTime(750000) - c.mu.RUnlock() - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - - t.Logf("now: %v", now) - - // Time should be between the current sample and where we'd expect the - // next sample. - if now < 700 || now > 800 { - t.Errorf("now got %v want > 700 && < 800", now) - } -} - -func TestErrorCorrection(t *testing.T) { - testCases := []struct { - name string - samples [5]sample - projectedTimeStart int64 - projectedTimeEnd int64 - }{ - // Initial calibration should be ~1MHz for each of these, and - // the reference clock changes in samples[2]. - { - name: "slow-down", - samples: [5]sample{ - {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())}, - {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())}, - // Reference clock has slowed down, causing 100ms of error. - {before: 3010000, after: 3010001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())}, - {before: 4020000, after: 4020001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())}, - {before: 5030000, after: 5030001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())}, - }, - projectedTimeStart: 3005 * time.Millisecond.Nanoseconds(), - projectedTimeEnd: 3015 * time.Millisecond.Nanoseconds(), - }, - { - name: "speed-up", - samples: [5]sample{ - {before: 1000000, after: 1000001, ref: ReferenceNS(1 * ApproxUpdateInterval.Nanoseconds())}, - {before: 2000000, after: 2000001, ref: ReferenceNS(2 * ApproxUpdateInterval.Nanoseconds())}, - // Reference clock has sped up, causing 100ms of error. - {before: 2990000, after: 2990001, ref: ReferenceNS(3 * ApproxUpdateInterval.Nanoseconds())}, - {before: 3980000, after: 3980001, ref: ReferenceNS(4 * ApproxUpdateInterval.Nanoseconds())}, - {before: 4970000, after: 4970001, ref: ReferenceNS(5 * ApproxUpdateInterval.Nanoseconds())}, - }, - projectedTimeStart: 2985 * time.Millisecond.Nanoseconds(), - projectedTimeEnd: 2995 * time.Millisecond.Nanoseconds(), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := newTestCalibratedClock(tc.samples[:], nil) - - // Initial calibration takes two updates. - _, ok := c.Update() - if ok { - t.Fatalf("Update ready too early") - } - - params, ok := c.Update() - if !ok { - t.Fatalf("Update not ready") - } - - // Initial calibration is ~1MHz. - hz := params.Frequency - if hz < 990000 || hz > 1010000 { - t.Fatalf("Frequency got %v want > 990kHz && < 1010kHz", hz) - } - - // Project time at the next update. Given the 1MHz - // calibration, it is expected to be ~3.1s/2.9s, not - // the actual 3s. - // - // N.B. the next update time is the "after" time above. - projected, ok := params.ComputeTime(tc.samples[2].after) - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - if projected < tc.projectedTimeStart || projected > tc.projectedTimeEnd { - t.Fatalf("ComputeTime(%v) got %v want > %v && < %v", tc.samples[2].after, projected, tc.projectedTimeStart, tc.projectedTimeEnd) - } - - // Update again to see the changed reference clock. - params, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - - // We now know that TSC = tc.samples[2].after -> 3s, - // but with the previous params indicated that TSC - // tc.samples[2].after -> 3.5s/2.5s. We can't allow the - // clock to go backwards, and having the clock jump - // forwards is undesirable. There should be a smooth - // transition that corrects the clock error over time. - // Check that the clock is continuous at TSC = - // tc.samples[2].after. - newProjected, ok := params.ComputeTime(tc.samples[2].after) - if !ok { - t.Fatalf("ComputeTime ok got %v want true", ok) - } - if newProjected != projected { - t.Errorf("Discontinuous time; ComputeTime(%v) got %v want %v", tc.samples[2].after, newProjected, projected) - } - - // As the reference clock stablizes, ensure that the clock error - // decreases. - initialErr := c.errorNS - t.Logf("initial error: %v ns", initialErr) - - _, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - if c.errorNS.Magnitude() > initialErr.Magnitude() { - t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr) - } - - _, ok = c.Update() - if !ok { - t.Fatalf("Update not ready") - } - if c.errorNS.Magnitude() > initialErr.Magnitude() { - t.Errorf("errorNS increased, got %v want |%v| <= |%v|", c.errorNS, c.errorNS, initialErr) - } - - t.Logf("final error: %v ns", c.errorNS) - }) - } -} diff --git a/pkg/sentry/time/parameters_test.go b/pkg/sentry/time/parameters_test.go deleted file mode 100644 index e1b9084ac..000000000 --- a/pkg/sentry/time/parameters_test.go +++ /dev/null @@ -1,486 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "math" - "testing" - "time" -) - -func TestParametersComputeTime(t *testing.T) { - testCases := []struct { - name string - params Parameters - now TSCValue - want int64 - }{ - { - // Now is the same as the base cycles. - name: "base-cycles", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 10000, - want: 5000 * time.Millisecond.Nanoseconds(), - }, - { - // Now is the behind the base cycles. Time is frozen. - name: "backwards", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 9000, - want: 5000 * time.Millisecond.Nanoseconds(), - }, - { - // Now is ahead of the base cycles. - name: "ahead", - params: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 15000, - want: 5500 * time.Millisecond.Nanoseconds(), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, ok := tc.params.ComputeTime(tc.now) - if !ok { - t.Errorf("ComputeTime ok got %v want true", got) - } - if got != tc.want { - t.Errorf("ComputeTime got %+v want %+v", got, tc.want) - } - }) - } -} - -func TestParametersErrorAdjust(t *testing.T) { - testCases := []struct { - name string - oldParams Parameters - now TSCValue - newParams Parameters - want Parameters - errorNS ReferenceNS - wantErr bool - }{ - { - // newParams are perfectly continuous with oldParams - // and don't need adjustment. - name: "continuous", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - newParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - }, - { - // Same as "continuous", but with now ahead of - // newParams.BaseCycles. The result is the same as - // there is no error to correct. - name: "continuous-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 60000, - newParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - }, - { - // errorAdjust bails out if the TSC goes backwards. - name: "tsc-backwards", - oldParams: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 9000, - newParams: Parameters{ - BaseCycles: 9000, - BaseRef: ReferenceNS(1100 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - wantErr: true, - }, - { - // errorAdjust bails out if new params are from after now. - name: "params-after-now", - oldParams: Parameters{ - BaseCycles: 10000, - BaseRef: ReferenceNS(1000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 11000, - newParams: Parameters{ - BaseCycles: 12000, - BaseRef: ReferenceNS(1200 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - wantErr: true, - }, - { - // Host clock sped up. - name: "speed-up", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 45000, - // Host frequency changed to 9000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - // From oldParams, we think ref = 4.5s at cycles = 45000. - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 9000, - }, - want: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(4500 * time.Millisecond.Nanoseconds()), - // We must decrease the new frequency by 50% to - // correct 0.5s of error in 1s - // (ApproxUpdateInterval). - Frequency: 4500, - }, - errorNS: ReferenceNS(-500 * time.Millisecond.Nanoseconds()), - }, - { - // Host clock sped up, with now ahead of newParams. - name: "speed-up-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - // Host frequency changed to 9000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 9000, - }, - // nextRef = 6000ms - // nextCycles = 9000 * (6000ms - 5000ms) + 45000 - // nextCycles = 9000 * (1s) + 45000 - // nextCycles = 54000 - // f = (54000 - 50000) / 1s = 4000 - // - // ref = 5000ms - (50000 - 45000) / 4000 - // ref = 3.75s - want: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(3750 * time.Millisecond.Nanoseconds()), - Frequency: 4000, - }, - // oldNow = 50000 * 10000 = 5s - // newNow = (50000 - 45000) / 9000 + 5s = 5.555s - errorNS: ReferenceNS((5000*time.Millisecond - 5555555555).Nanoseconds()), - }, - { - // Host clock sped up. The new parameters are so far - // ahead that the next update time already passed. - name: "speed-up-uncorrectable-baseref", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 50000, - // Host frequency changed to 5000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(9000 * time.Millisecond.Nanoseconds()), - Frequency: 5000, - }, - // The next update should be at 10s, but newParams - // already passed 6s. Thus it is impossible to correct - // the clock by then. - wantErr: true, - }, - { - // Host clock sped up. The new parameters are moving so - // fast that the next update should be before now. - name: "speed-up-uncorrectable-frequency", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 55000, - // Host frequency changed to 7500 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 45000, - BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()), - Frequency: 7500, - }, - // The next update should be at 6.5s, but newParams are - // so far ahead and fast that they reach 6.5s at cycle - // 48750, which before now! Thus it is impossible to - // correct the clock by then. - wantErr: true, - }, - { - // Host clock slowed down. - name: "slow-down", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 55000, - // Host frequency changed to 11000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 55000, - // From oldParams, we think ref = 5.5s at cycles = 55000. - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 11000, - }, - want: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5500 * time.Millisecond.Nanoseconds()), - // We must increase the new frequency by 50% to - // correct 0.5s of error in 1s - // (ApproxUpdateInterval). - Frequency: 16500, - }, - errorNS: ReferenceNS(500 * time.Millisecond.Nanoseconds()), - }, - { - // Host clock slowed down, with now ahead of newParams. - name: "slow-down-nowdiff", - oldParams: Parameters{ - BaseCycles: 0, - BaseRef: 0, - Frequency: 10000, - }, - now: 60000, - // Host frequency changed to 11000 immediately after - // oldParams was returned. - newParams: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 11000, - }, - // nextRef = 7000ms - // nextCycles = 11000 * (7000ms - 5000ms) + 55000 - // nextCycles = 11000 * (2000ms) + 55000 - // nextCycles = 77000 - // f = (77000 - 60000) / 1s = 17000 - // - // ref = 6000ms - (60000 - 55000) / 17000 - // ref = 5705882353ns - want: Parameters{ - BaseCycles: 55000, - BaseRef: ReferenceNS(5705882353), - Frequency: 17000, - }, - // oldNow = 60000 * 10000 = 6s - // newNow = (60000 - 55000) / 11000 + 5s = 5.4545s - errorNS: ReferenceNS((6*time.Second - 5454545454).Nanoseconds()), - }, - { - // Host time went backwards. - name: "time-backwards", - oldParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 60000, - newParams: Parameters{ - BaseCycles: 60000, - // From oldParams, we think ref = 6s at cycles = 60000. - BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(6000 * time.Millisecond.Nanoseconds()), - // We must increase the frequency by 200% to - // correct 2s of error in 1s - // (ApproxUpdateInterval). - Frequency: 30000, - }, - errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()), - }, - { - // Host time went backwards, with now ahead of newParams. - name: "time-backwards-nowdiff", - oldParams: Parameters{ - BaseCycles: 50000, - BaseRef: ReferenceNS(5000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - now: 65000, - // nextRef = 7500ms - // nextCycles = 10000 * (7500ms - 4000ms) + 60000 - // nextCycles = 10000 * (3500ms) + 60000 - // nextCycles = 95000 - // f = (95000 - 65000) / 1s = 30000 - // - // ref = 6500ms - (65000 - 60000) / 30000 - // ref = 6333333333ns - newParams: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(4000 * time.Millisecond.Nanoseconds()), - Frequency: 10000, - }, - want: Parameters{ - BaseCycles: 60000, - BaseRef: ReferenceNS(6333333334), - Frequency: 30000, - }, - // oldNow = 65000 * 10000 = 6.5s - // newNow = (65000 - 60000) / 10000 + 4s = 4.5s - errorNS: ReferenceNS(2000 * time.Millisecond.Nanoseconds()), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, errorNS, err := errorAdjust(tc.oldParams, tc.newParams, tc.now) - if err != nil && !tc.wantErr { - t.Errorf("err got %v want nil", err) - } else if err == nil && tc.wantErr { - t.Errorf("err got nil want non-nil") - } - - if got != tc.want { - t.Errorf("Parameters got %+v want %+v", got, tc.want) - } - if errorNS != tc.errorNS { - t.Errorf("errorNS got %v want %v", errorNS, tc.errorNS) - } - }) - } -} - -func testMuldiv(t *testing.T, v uint64) { - for i := uint64(1); i <= 1000000; i++ { - mult := uint64(1000000000) - div := i * mult - res, ok := muldiv64(v, mult, div) - if !ok { - t.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div) - } - if want := v / i; res != want { - t.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want) - } - } -} - -func TestMulDiv(t *testing.T) { - testMuldiv(t, math.MaxUint64) - for i := int64(-10); i <= 10; i++ { - testMuldiv(t, uint64(i)) - } -} - -func TestMulDivZero(t *testing.T) { - if r, ok := muldiv64(2, 4, 0); ok { - t.Errorf("muldiv64(2, 4, 0) got %d, ok want !ok", r) - } - - if r, ok := muldiv64(0, 0, 0); ok { - t.Errorf("muldiv64(0, 0, 0) got %d, ok want !ok", r) - } -} - -func TestMulDivOverflow(t *testing.T) { - testCases := []struct { - name string - val uint64 - mult uint64 - div uint64 - ok bool - ret uint64 - }{ - { - name: "2^62", - val: 1 << 63, - mult: 4, - div: 8, - ok: true, - ret: 1 << 62, - }, - { - name: "2^64-1", - val: 0xffffffffffffffff, - mult: 1, - div: 1, - ok: true, - ret: 0xffffffffffffffff, - }, - { - name: "2^64", - val: 1 << 63, - mult: 4, - div: 2, - ok: false, - }, - { - name: "2^125", - val: 1 << 63, - mult: 1 << 63, - div: 2, - ok: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - r, ok := muldiv64(tc.val, tc.mult, tc.div) - if ok != tc.ok { - t.Errorf("ok got %v want %v", ok, tc.ok) - } - if tc.ok && r != tc.ret { - t.Errorf("ret got %v want %v", r, tc.ret) - } - }) - } -} diff --git a/pkg/sentry/time/sampler_test.go b/pkg/sentry/time/sampler_test.go deleted file mode 100644 index 3e70a1134..000000000 --- a/pkg/sentry/time/sampler_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package time - -import ( - "errors" - "testing" -) - -// errNoSamples is returned when testReferenceClocks runs out of samples. -var errNoSamples = errors.New("no samples available") - -// testReferenceClocks returns a preset list of samples and cycle counts. -type testReferenceClocks struct { - samples []sample - cycles []TSCValue -} - -// Sample implements referenceClocks.Sample, returning the next sample in the list. -func (t *testReferenceClocks) Sample(_ ClockID) (sample, error) { - if len(t.samples) == 0 { - return sample{}, errNoSamples - } - - s := t.samples[0] - if len(t.samples) == 1 { - t.samples = nil - } else { - t.samples = t.samples[1:] - } - - return s, nil -} - -// Cycles implements referenceClocks.Cycles, returning the next TSCValue in the list. -func (t *testReferenceClocks) Cycles() TSCValue { - if len(t.cycles) == 0 { - return 0 - } - - c := t.cycles[0] - if len(t.cycles) == 1 { - t.cycles = nil - } else { - t.cycles = t.cycles[1:] - } - - return c -} - -// newTestSampler returns a sampler that collects samples from -// the given sample list and cycle counts from the given cycle list. -func newTestSampler(samples []sample, cycles []TSCValue) *sampler { - return &sampler{ - clocks: &testReferenceClocks{ - samples: samples, - cycles: cycles, - }, - overhead: defaultOverheadCycles, - } -} - -// generateSamples generates n samples with the given overhead. -func generateSamples(n int, overhead TSCValue) []sample { - samples := []sample{{before: 1000000, after: 1000000 + overhead, ref: 100}} - for i := 0; i < n-1; i++ { - prev := samples[len(samples)-1] - samples = append(samples, sample{ - before: prev.before + 1000000, - after: prev.after + 1000000, - ref: prev.ref + 100, - }) - } - return samples -} - -// TestSample ensures that samples can be collected. -func TestSample(t *testing.T) { - testCases := []struct { - name string - samples []sample - err error - }{ - { - name: "basic", - samples: []sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - }, - err: nil, - }, - { - // Sample with backwards TSC ignored. - // referenceClock should retry and get errNoSamples. - name: "backwards-tsc-ignored", - samples: []sample{ - {before: 100000, after: 90000, ref: 100}, - }, - err: errNoSamples, - }, - { - // Sample far above overhead skipped. - // referenceClock should retry and get errNoSamples. - name: "reject-overhead", - samples: []sample{ - {before: 100000, after: 100000 + 5*defaultOverheadCycles, ref: 100}, - }, - err: errNoSamples, - }, - { - // Maximum overhead allowed is bounded. - name: "over-max-overhead", - // Generate a bunch of samples. The reference clock - // needs a while to ramp up its expected overhead. - samples: generateSamples(100, 2*maxOverheadCycles), - err: errOverheadTooHigh, - }, - { - // Overhead at maximum overhead is allowed. - name: "max-overhead", - // Generate a bunch of samples. The reference clock - // needs a while to ramp up its expected overhead. - samples: generateSamples(100, maxOverheadCycles), - err: nil, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := newTestSampler(tc.samples, nil) - err := s.Sample() - if err != tc.err { - t.Errorf("Sample err got %v want %v", err, tc.err) - } - }) - } -} - -// TestOutliersIgnored tests that referenceClock ignores samples with very high -// overhead. -func TestOutliersIgnored(t *testing.T) { - s := newTestSampler([]sample{ - {before: 100000, after: 100000 + defaultOverheadCycles, ref: 100}, - {before: 200000, after: 200000 + defaultOverheadCycles, ref: 200}, - {before: 300000, after: 300000 + defaultOverheadCycles, ref: 300}, - {before: 400000, after: 400000 + defaultOverheadCycles, ref: 400}, - {before: 500000, after: 500000 + 5*defaultOverheadCycles, ref: 500}, // Ignored - {before: 600000, after: 600000 + defaultOverheadCycles, ref: 600}, - {before: 700000, after: 700000 + defaultOverheadCycles, ref: 700}, - }, nil) - - // Collect 5 samples. - for i := 0; i < 5; i++ { - err := s.Sample() - if err != nil { - t.Fatalf("Unexpected error while sampling: %v", err) - } - } - - oldest, newest, ok := s.Range() - if !ok { - t.Fatalf("Range not ok") - } - - if oldest.ref != 100 { - t.Errorf("oldest.ref got %v want %v", oldest.ref, 100) - } - - // We skipped the high-overhead sample. - if newest.ref != 600 { - t.Errorf("newest.ref got %v want %v", newest.ref, 600) - } -} diff --git a/pkg/sentry/time/seqatomic_parameters.go b/pkg/sentry/time/seqatomic_parameters.go new file mode 100755 index 000000000..89792c56d --- /dev/null +++ b/pkg/sentry/time/seqatomic_parameters.go @@ -0,0 +1,55 @@ +package time + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/third_party/gvsync" +) + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoadParameters(sc *gvsync.SeqCount, ptr *Parameters) Parameters { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Parameters + for { + epoch := sc.BeginRead() + if gvsync.RaceEnabled { + + gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoadParameters(sc *gvsync.SeqCount, epoch gvsync.SeqCountEpoch, ptr *Parameters) (Parameters, bool) { + var val Parameters + if gvsync.RaceEnabled { + gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func initParameters() { + var val Parameters + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := gvsync.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/sentry/time/time_state_autogen.go b/pkg/sentry/time/time_state_autogen.go new file mode 100755 index 000000000..ea614b056 --- /dev/null +++ b/pkg/sentry/time/time_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package time + diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD deleted file mode 100644 index b69603da3..000000000 --- a/pkg/sentry/unimpl/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -proto_library( - name = "unimplemented_syscall_proto", - srcs = ["unimplemented_syscall.proto"], - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_proto"], -) - -go_proto_library( - name = "unimplemented_syscall_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", - proto = ":unimplemented_syscall_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - -go_library( - name = "unimpl", - srcs = ["events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl", - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/sentry/context", - ], -) diff --git a/pkg/sentry/unimpl/unimpl_state_autogen.go b/pkg/sentry/unimpl/unimpl_state_autogen.go new file mode 100755 index 000000000..b9d1116f3 --- /dev/null +++ b/pkg/sentry/unimpl/unimpl_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package unimpl + diff --git a/pkg/sentry/unimpl/unimplemented_syscall.proto b/pkg/sentry/unimpl/unimplemented_syscall.proto deleted file mode 100644 index 0d7a94be7..000000000 --- a/pkg/sentry/unimpl/unimplemented_syscall.proto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor; - -import "pkg/sentry/arch/registers.proto"; - -message UnimplementedSyscall { - // Task ID. - int32 tid = 1; - - // Registers at the time of the call. - Registers registers = 2; -} diff --git a/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go new file mode 100755 index 000000000..4dfb169cc --- /dev/null +++ b/pkg/sentry/unimpl/unimplemented_syscall_go_proto/unimplemented_syscall.pb.go @@ -0,0 +1,91 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/sentry/unimpl/unimplemented_syscall.proto + +package gvisor + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + registers_go_proto "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type UnimplementedSyscall struct { + Tid int32 `protobuf:"varint,1,opt,name=tid,proto3" json:"tid,omitempty"` + Registers *registers_go_proto.Registers `protobuf:"bytes,2,opt,name=registers,proto3" json:"registers,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *UnimplementedSyscall) Reset() { *m = UnimplementedSyscall{} } +func (m *UnimplementedSyscall) String() string { return proto.CompactTextString(m) } +func (*UnimplementedSyscall) ProtoMessage() {} +func (*UnimplementedSyscall) Descriptor() ([]byte, []int) { + return fileDescriptor_ddc2fcd2bea3c75d, []int{0} +} + +func (m *UnimplementedSyscall) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_UnimplementedSyscall.Unmarshal(m, b) +} +func (m *UnimplementedSyscall) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_UnimplementedSyscall.Marshal(b, m, deterministic) +} +func (m *UnimplementedSyscall) XXX_Merge(src proto.Message) { + xxx_messageInfo_UnimplementedSyscall.Merge(m, src) +} +func (m *UnimplementedSyscall) XXX_Size() int { + return xxx_messageInfo_UnimplementedSyscall.Size(m) +} +func (m *UnimplementedSyscall) XXX_DiscardUnknown() { + xxx_messageInfo_UnimplementedSyscall.DiscardUnknown(m) +} + +var xxx_messageInfo_UnimplementedSyscall proto.InternalMessageInfo + +func (m *UnimplementedSyscall) GetTid() int32 { + if m != nil { + return m.Tid + } + return 0 +} + +func (m *UnimplementedSyscall) GetRegisters() *registers_go_proto.Registers { + if m != nil { + return m.Registers + } + return nil +} + +func init() { + proto.RegisterType((*UnimplementedSyscall)(nil), "gvisor.UnimplementedSyscall") +} + +func init() { + proto.RegisterFile("pkg/sentry/unimpl/unimplemented_syscall.proto", fileDescriptor_ddc2fcd2bea3c75d) +} + +var fileDescriptor_ddc2fcd2bea3c75d = []byte{ + // 149 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2d, 0xc8, 0x4e, 0xd7, + 0x2f, 0x4e, 0xcd, 0x2b, 0x29, 0xaa, 0xd4, 0x2f, 0xcd, 0xcb, 0xcc, 0x2d, 0xc8, 0x81, 0x52, 0xa9, + 0xb9, 0xa9, 0x79, 0x25, 0xa9, 0x29, 0xf1, 0xc5, 0x95, 0xc5, 0xc9, 0x89, 0x39, 0x39, 0x7a, 0x05, + 0x45, 0xf9, 0x25, 0xf9, 0x42, 0x6c, 0xe9, 0x65, 0x99, 0xc5, 0xf9, 0x45, 0x52, 0xf2, 0x48, 0xda, + 0x12, 0x8b, 0x92, 0x33, 0xf4, 0x8b, 0x52, 0xd3, 0x33, 0x8b, 0x4b, 0x52, 0x8b, 0x8a, 0x21, 0x0a, + 0x95, 0x22, 0xb9, 0x44, 0x42, 0x91, 0xcd, 0x09, 0x86, 0x18, 0x23, 0x24, 0xc0, 0xc5, 0x5c, 0x92, + 0x99, 0x22, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0x1a, 0x04, 0x62, 0x0a, 0xe9, 0x73, 0x71, 0xc2, 0x35, + 0x4b, 0x30, 0x29, 0x30, 0x6a, 0x70, 0x1b, 0x09, 0xea, 0x41, 0xac, 0xd1, 0x0b, 0x82, 0x49, 0x04, + 0x21, 0xd4, 0x24, 0xb1, 0x81, 0x6d, 0x30, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x51, 0x4a, 0x47, + 0x79, 0xbb, 0x00, 0x00, 0x00, +} diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD deleted file mode 100644 index 86a87edd4..000000000 --- a/pkg/sentry/uniqueid/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "uniqueid", - srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/uniqueid", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/socket/unix/transport", - ], -) diff --git a/pkg/sentry/uniqueid/uniqueid_state_autogen.go b/pkg/sentry/uniqueid/uniqueid_state_autogen.go new file mode 100755 index 000000000..09e4327e4 --- /dev/null +++ b/pkg/sentry/uniqueid/uniqueid_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package uniqueid + diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD deleted file mode 100644 index a34c39540..000000000 --- a/pkg/sentry/usage/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "usage", - srcs = [ - "cpu.go", - "io.go", - "memory.go", - "memory_unsafe.go", - "usage.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usage", - visibility = [ - "//pkg/sentry:internal", - ], - deps = [ - "//pkg/bits", - "//pkg/memutil", - ], -) diff --git a/pkg/sentry/usage/usage_state_autogen.go b/pkg/sentry/usage/usage_state_autogen.go new file mode 100755 index 000000000..d98e3e916 --- /dev/null +++ b/pkg/sentry/usage/usage_state_autogen.go @@ -0,0 +1,50 @@ +// automatically generated by stateify. + +package usage + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *CPUStats) beforeSave() {} +func (x *CPUStats) save(m state.Map) { + x.beforeSave() + m.Save("UserTime", &x.UserTime) + m.Save("SysTime", &x.SysTime) + m.Save("VoluntarySwitches", &x.VoluntarySwitches) +} + +func (x *CPUStats) afterLoad() {} +func (x *CPUStats) load(m state.Map) { + m.Load("UserTime", &x.UserTime) + m.Load("SysTime", &x.SysTime) + m.Load("VoluntarySwitches", &x.VoluntarySwitches) +} + +func (x *IO) beforeSave() {} +func (x *IO) save(m state.Map) { + x.beforeSave() + m.Save("CharsRead", &x.CharsRead) + m.Save("CharsWritten", &x.CharsWritten) + m.Save("ReadSyscalls", &x.ReadSyscalls) + m.Save("WriteSyscalls", &x.WriteSyscalls) + m.Save("BytesRead", &x.BytesRead) + m.Save("BytesWritten", &x.BytesWritten) + m.Save("BytesWriteCancelled", &x.BytesWriteCancelled) +} + +func (x *IO) afterLoad() {} +func (x *IO) load(m state.Map) { + m.Load("CharsRead", &x.CharsRead) + m.Load("CharsWritten", &x.CharsWritten) + m.Load("ReadSyscalls", &x.ReadSyscalls) + m.Load("WriteSyscalls", &x.WriteSyscalls) + m.Load("BytesRead", &x.BytesRead) + m.Load("BytesWritten", &x.BytesWritten) + m.Load("BytesWriteCancelled", &x.BytesWriteCancelled) +} + +func init() { + state.Register("usage.CPUStats", (*CPUStats)(nil), state.Fns{Save: (*CPUStats).save, Load: (*CPUStats).load}) + state.Register("usage.IO", (*IO)(nil), state.Fns{Save: (*IO).save, Load: (*IO).load}) +} diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD deleted file mode 100644 index a5b4206bb..000000000 --- a/pkg/sentry/usermem/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "usermem", - prefix = "Addr", - template = "//pkg/segment:generic_range", - types = { - "T": "Addr", - }, -) - -go_library( - name = "usermem", - srcs = [ - "access_type.go", - "addr.go", - "addr_range.go", - "addr_range_seq_unsafe.go", - "bytes_io.go", - "bytes_io_unsafe.go", - "usermem.go", - "usermem_arm64.go", - "usermem_unsafe.go", - "usermem_x86.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usermem", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/atomicbitops", - "//pkg/binary", - "//pkg/log", - "//pkg/sentry/context", - "//pkg/sentry/safemem", - "//pkg/syserror", - "//pkg/tcpip/buffer", - ], -) - -go_test( - name = "usermem_test", - size = "small", - srcs = [ - "addr_range_seq_test.go", - "usermem_test.go", - ], - embed = [":usermem"], - deps = [ - "//pkg/sentry/context", - "//pkg/sentry/safemem", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/usermem/README.md b/pkg/sentry/usermem/README.md deleted file mode 100644 index f6d2137eb..000000000 --- a/pkg/sentry/usermem/README.md +++ /dev/null @@ -1,31 +0,0 @@ -This package defines primitives for sentry access to application memory. - -Major types: - -- The `IO` interface represents a virtual address space and provides I/O - methods on that address space. `IO` is the lowest-level primitive. The - primary implementation of the `IO` interface is `mm.MemoryManager`. - -- `IOSequence` represents a collection of individually-contiguous address - ranges in a `IO` that is operated on sequentially, analogous to Linux's - `struct iov_iter`. - -Major usage patterns: - -- Access to a task's virtual memory, subject to the application's memory - protections and while running on that task's goroutine, from a context that - is at or above the level of the `kernel` package (e.g. most syscall - implementations in `syscalls/linux`); use the `kernel.Task.Copy*` wrappers - defined in `kernel/task_usermem.go`. - -- Access to a task's virtual memory, from a context that is at or above the - level of the `kernel` package, but where any of the above constraints does - not hold (e.g. `PTRACE_POKEDATA`, which ignores application memory - protections); obtain the task's `mm.MemoryManager` by calling - `kernel.Task.MemoryManager`, and call its `IO` methods directly. - -- Access to a task's virtual memory, from a context that is below the level of - the `kernel` package (e.g. filesystem I/O); clients must pass I/O arguments - from higher layers, usually in the form of an `IOSequence`. The - `kernel.Task.SingleIOSequence` and `kernel.Task.IovecsIOSequence` functions - in `kernel/task_usermem.go` are convenience functions for doing so. diff --git a/pkg/sentry/usermem/addr_range.go b/pkg/sentry/usermem/addr_range.go new file mode 100755 index 000000000..152ed1434 --- /dev/null +++ b/pkg/sentry/usermem/addr_range.go @@ -0,0 +1,62 @@ +package usermem + +// A Range represents a contiguous range of T. +// +// +stateify savable +type AddrRange struct { + // Start is the inclusive start of the range. + Start Addr + + // End is the exclusive end of the range. + End Addr +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r AddrRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r AddrRange) Length() Addr { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r AddrRange) Contains(x Addr) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r AddrRange) Overlaps(r2 AddrRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r AddrRange) IsSupersetOf(r2 AddrRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r AddrRange) Intersect(r2 AddrRange) AddrRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r AddrRange) CanSplitAt(x Addr) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/sentry/usermem/addr_range_seq_test.go b/pkg/sentry/usermem/addr_range_seq_test.go deleted file mode 100644 index 82f735026..000000000 --- a/pkg/sentry/usermem/addr_range_seq_test.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "testing" -) - -var addrRangeSeqTests = []struct { - desc string - ranges []AddrRange -}{ - { - desc: "Empty sequence", - }, - { - desc: "Single empty AddrRange", - ranges: []AddrRange{ - {0x10, 0x10}, - }, - }, - { - desc: "Single non-empty AddrRange of length 1", - ranges: []AddrRange{ - {0x10, 0x11}, - }, - }, - { - desc: "Single non-empty AddrRange of length 2", - ranges: []AddrRange{ - {0x10, 0x12}, - }, - }, - { - desc: "Multiple non-empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - }, - }, - { - desc: "Multiple AddrRanges including empty AddrRanges", - ranges: []AddrRange{ - {0x10, 0x10}, - {0x20, 0x20}, - {0x30, 0x33}, - {0x40, 0x44}, - {0x50, 0x50}, - {0x60, 0x60}, - {0x70, 0x77}, - {0x80, 0x88}, - {0x90, 0x90}, - {0xa0, 0xa0}, - }, - }, -} - -func testAddrRangeSeqEqualityWithTailIteration(t *testing.T, ars AddrRangeSeq, wantRanges []AddrRange) { - var wantLen int64 - for _, ar := range wantRanges { - wantLen += int64(ar.Length()) - } - - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - if gotN, wantN := ars.NumRanges(), len(wantRanges)-i; gotN != wantN { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted %d", i, ars, gotN, wantN) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - ars = ars.Tail() - wantLen -= int64(got.Length()) - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - if gotN := ars.NumRanges(); gotN != 0 { - t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted 0", i, ars, gotN) - } -} - -func TestAddrRangeSeqTailIteration(t *testing.T) { - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - testAddrRangeSeqEqualityWithTailIteration(t, AddrRangeSeqFromSlice(test.ranges), test.ranges) - }) - } -} - -func TestAddrRangeSeqDropFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.DropFirst(1), ars; got != want { - t.Errorf("%v.DropFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqDropSingleByteIteration(t *testing.T) { - // Tests AddrRangeSeq iteration using Head/DropFirst, simulating - // I/O-per-AddrRange. - for _, test := range addrRangeSeqTests { - t.Run(test.desc, func(t *testing.T) { - // Figure out what AddrRanges we expect to see. - var wantLen int64 - var wantRanges []AddrRange - for _, ar := range test.ranges { - wantLen += int64(ar.Length()) - wantRanges = append(wantRanges, ar) - if ar.Length() == 0 { - // We "do" 0 bytes of I/O and then call DropFirst(0), - // advancing to the next AddrRange. - continue - } - // Otherwise we "do" 1 byte of I/O and then call DropFirst(1), - // advancing the AddrRange by 1 byte, or to the next AddrRange - // if this one is exhausted. - for ar.Start++; ar.Length() != 0; ar.Start++ { - wantRanges = append(wantRanges, ar) - } - } - t.Logf("Expected AddrRanges: %s (%d bytes)", wantRanges, wantLen) - - ars := AddrRangeSeqFromSlice(test.ranges) - var i int - for !ars.IsEmpty() { - if gotLen := ars.NumBytes(); gotLen != wantLen { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen) - } - got := ars.Head() - if i >= len(wantRanges) { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got) - } else if want := wantRanges[i]; got != want { - t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want) - } - if got.Length() == 0 { - ars = ars.DropFirst(0) - } else { - ars = ars.DropFirst(1) - wantLen-- - } - i++ - } - if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 { - t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen) - } - }) - } -} - -func TestAddrRangeSeqTakeFirstEmpty(t *testing.T) { - var ars AddrRangeSeq - if got, want := ars.TakeFirst(1), ars; got != want { - t.Errorf("%v.TakeFirst(1): got %v, wanted %v", ars, got, want) - } -} - -func TestAddrRangeSeqTakeFirst(t *testing.T) { - ranges := []AddrRange{ - {0x10, 0x11}, - {0x20, 0x22}, - {0x30, 0x30}, - {0x40, 0x44}, - {0x50, 0x55}, - {0x60, 0x60}, - {0x70, 0x77}, - } - ars := AddrRangeSeqFromSlice(ranges).TakeFirst(5) - want := []AddrRange{ - {0x10, 0x11}, // +1 byte (total 1 byte), not truncated - {0x20, 0x22}, // +2 bytes (total 3 bytes), not truncated - {0x30, 0x30}, // +0 bytes (total 3 bytes), no change - {0x40, 0x42}, // +2 bytes (total 5 bytes), partially truncated - {0x50, 0x50}, // +0 bytes (total 5 bytes), fully truncated - {0x60, 0x60}, // +0 bytes (total 5 bytes), "fully truncated" (no change) - {0x70, 0x70}, // +0 bytes (total 5 bytes), fully truncated - } - testAddrRangeSeqEqualityWithTailIteration(t, ars, want) -} diff --git a/pkg/sentry/usermem/usermem_state_autogen.go b/pkg/sentry/usermem/usermem_state_autogen.go new file mode 100755 index 000000000..39b92d108 --- /dev/null +++ b/pkg/sentry/usermem/usermem_state_autogen.go @@ -0,0 +1,49 @@ +// automatically generated by stateify. + +package usermem + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *AccessType) beforeSave() {} +func (x *AccessType) save(m state.Map) { + x.beforeSave() + m.Save("Read", &x.Read) + m.Save("Write", &x.Write) + m.Save("Execute", &x.Execute) +} + +func (x *AccessType) afterLoad() {} +func (x *AccessType) load(m state.Map) { + m.Load("Read", &x.Read) + m.Load("Write", &x.Write) + m.Load("Execute", &x.Execute) +} + +func (x *Addr) save(m state.Map) { + m.SaveValue("", (uintptr)(*x)) +} + +func (x *Addr) load(m state.Map) { + m.LoadValue("", new(uintptr), func(y interface{}) { *x = (Addr)(y.(uintptr)) }) +} + +func (x *AddrRange) beforeSave() {} +func (x *AddrRange) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *AddrRange) afterLoad() {} +func (x *AddrRange) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func init() { + state.Register("usermem.AccessType", (*AccessType)(nil), state.Fns{Save: (*AccessType).save, Load: (*AccessType).load}) + state.Register("usermem.Addr", (*Addr)(nil), state.Fns{Save: (*Addr).save, Load: (*Addr).load}) + state.Register("usermem.AddrRange", (*AddrRange)(nil), state.Fns{Save: (*AddrRange).save, Load: (*AddrRange).load}) +} diff --git a/pkg/sentry/usermem/usermem_test.go b/pkg/sentry/usermem/usermem_test.go deleted file mode 100644 index 299f64754..000000000 --- a/pkg/sentry/usermem/usermem_test.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package usermem - -import ( - "bytes" - "encoding/binary" - "fmt" - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" - "gvisor.dev/gvisor/pkg/syserror" -) - -// newContext returns a context.Context that we can use in these tests (we -// can't use contexttest because it depends on usermem). -func newContext() context.Context { - return context.Background() -} - -func newBytesIOString(s string) *BytesIO { - return &BytesIO{[]byte(s)} -} - -func TestBytesIOCopyOutSuccess(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) - if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInSuccess(t *testing.T) { - b := newBytesIOString("AfooE") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInFailure(t *testing.T) { - b := newBytesIOString("Afo") - var dst [3]byte - n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) - if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutSuccess(t *testing.T) { - b := newBytesIOString("ABCD") - n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{}) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOZeroOutFailure(t *testing.T) { - b := newBytesIOString("ABC") - n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{}) - if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromSuccess(t *testing.T) { - b := newBytesIOString("ABCDEFGH") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyOutFromFailure(t *testing.T) { - b := newBytesIOString("ABCDE") - n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) - if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) { - t.Errorf("Bytes: got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToSuccess(t *testing.T) { - b := newBytesIOString("AfoobarH") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 4, End: 7}, - {Start: 1, End: 4}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN := int64(6); n != wantN || err != nil { - t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -func TestBytesIOCopyInToFailure(t *testing.T) { - b := newBytesIOString("Afoob") - var dst bytes.Buffer - n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{ - {Start: 1, End: 4}, - {Start: 4, End: 7}, - }), safemem.FromIOWriter{&dst}, IOOpts{}) - if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { - t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) - } - if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { - t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) - } -} - -type testStruct struct { - Int8 int8 - Uint8 uint8 - Int16 int16 - Uint16 uint16 - Int32 int32 - Uint32 uint32 - Int64 int64 - Uint64 uint64 -} - -func TestCopyObject(t *testing.T) { - wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8} - wantN := binary.Size(wantObj) - b := &BytesIO{make([]byte, wantN)} - ctx := newContext() - if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil { - t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - var gotObj testStruct - if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil { - t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if gotObj != wantObj { - t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj) - } -} - -func TestCopyStringInShort(t *testing.T) { - // Tests for string length <= copyStringIncrement. - want := strings.Repeat("A", copyStringIncrement-2) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInLong(t *testing.T) { - // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen - // (requiring multiple calls to IO.CopyIn()). - want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInVeryLong(t *testing.T) { - // Tests for string length > copyStringMaxInitBufLen (requiring buffer - // reallocation). - want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) - mem := want + "\x00" - if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) - } -} - -func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { - want := strings.Repeat("A", copyStringIncrement-1) - got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) - if wantErr := syserror.EFAULT; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyStringInTruncatedByMaxlen(t *testing.T) { - got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{}) - if want, wantErr := strings.Repeat("A", 5), syserror.ENAMETOOLONG; got != want || err != wantErr { - t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) - } -} - -func TestCopyInt32StringsInVec(t *testing.T) { - for _, test := range []struct { - str string - n int - initial []int32 - final []int32 - }{ - { - str: "100 200", - n: len("100 200"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Fewer values ok - str: "100", - n: len("100"), - initial: []int32{1, 2}, - final: []int32{100, 2}, - }, - { - // Extra values ok - str: "100 200 300", - n: len("100 200 "), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - { - // Leading and trailing whitespace ok - str: " 100\t200\n", - n: len(" 100\t200\n"), - initial: []int32{1, 2}, - final: []int32{100, 200}, - }, - } { - t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { - src := BytesIOSequence([]byte(test.str)) - dsts := append([]int32(nil), test.initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n) - } - if !reflect.DeepEqual(dsts, test.final) { - t.Errorf("dsts: got %v, wanted %v", dsts, test.final) - } - }) - } -} - -func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { - for _, s := range []string{"", "\n", "a123"} { - t.Run(fmt.Sprintf("%q", s), func(t *testing.T) { - src := BytesIOSequence([]byte(s)) - initial := []int32{1, 2} - dsts := append([]int32(nil), initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL) - } - if !reflect.DeepEqual(dsts, initial) { - t.Errorf("dsts: got %v, wanted %v", dsts, initial) - } - }) - } -} - -func TestIOSequenceCopyOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // CopyOut limited by len(src). - n, err := s.CopyOut(newContext(), []byte("fo")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyOut limited by s.NumBytes(). - n, err = s.CopyOut(newContext(), []byte("obar")) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foob"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceCopyIn(t *testing.T) { - s := BytesIOSequence([]byte("foob")) - dst := []byte("ABCDEF") - - // CopyIn limited by len(dst). - n, err := s.CopyIn(newContext(), dst[:2]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foCDEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // CopyIn limited by s.Remaining(). - n, err = s.CopyIn(newContext(), dst[2:]) - if wantN := 2; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("foobEF"); !bytes.Equal(dst, want) { - t.Errorf("dst: got %q, wanted %q", dst, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceZeroOut(t *testing.T) { - buf := []byte("ABCD") - s := BytesIOSequence(buf) - - // ZeroOut limited by toZero. - n, err := s.ZeroOut(newContext(), 2) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(2); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // ZeroOut limited by s.NumBytes(). - n, err = s.ZeroOut(newContext(), 4) - if wantN := int64(2); n != wantN || err != nil { - t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) { - t.Errorf("buf: got %q, wanted %q", buf, want) - } - s = s.DropFirst(2) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} - -func TestIOSequenceTakeFirst(t *testing.T) { - s := BytesIOSequence([]byte("foobar")) - if got, want := s.NumBytes(), int64(6); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - s = s.TakeFirst(3) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - // TakeFirst(n) where n > s.NumBytes() is a no-op. - s = s.TakeFirst(9) - if got, want := s.NumBytes(), int64(3); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } - - var dst [3]byte - n, err := s.CopyIn(newContext(), dst[:]) - if wantN := 3; n != wantN || err != nil { - t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { - t.Errorf("dst: got %q, wanted %q", got, want) - } - s = s.DropFirst(3) - if got, want := s.NumBytes(), int64(0); got != want { - t.Errorf("NumBytes: got %v, wanted %v", got, want) - } -} diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD deleted file mode 100644 index 4d8435265..000000000 --- a/pkg/sentry/watchdog/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "watchdog", - srcs = ["watchdog.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/watchdog", - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/log", - "//pkg/metric", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - ], -) diff --git a/pkg/sentry/watchdog/watchdog_state_autogen.go b/pkg/sentry/watchdog/watchdog_state_autogen.go new file mode 100755 index 000000000..530ac6a07 --- /dev/null +++ b/pkg/sentry/watchdog/watchdog_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package watchdog + diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD deleted file mode 100644 index 00665c939..000000000 --- a/pkg/sleep/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sleep", - srcs = [ - "commit_amd64.s", - "commit_asm.go", - "commit_noasm.go", - "sleep_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sleep", - visibility = ["//:sandbox"], -) - -go_test( - name = "sleep_test", - size = "medium", - srcs = [ - "sleep_test.go", - ], - embed = [":sleep"], -) diff --git a/pkg/sleep/empty.s b/pkg/sleep/empty.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/sleep/empty.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/sleep/sleep_state_autogen.go b/pkg/sleep/sleep_state_autogen.go new file mode 100755 index 000000000..e444aa91a --- /dev/null +++ b/pkg/sleep/sleep_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package sleep + diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go deleted file mode 100644 index 130806c86..000000000 --- a/pkg/sleep/sleep_test.go +++ /dev/null @@ -1,542 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sleep - -import ( - "math/rand" - "runtime" - "testing" - "time" -) - -// ZeroWakerNotAsserted tests that a zero-value waker is in non-asserted state. -func ZeroWakerNotAsserted(t *testing.T) { - var w Waker - if w.IsAsserted() { - t.Fatalf("Zero waker is asserted") - } - - if w.Clear() { - t.Fatalf("Zero waker is asserted") - } -} - -// AssertedWakerAfterAssert tests that a waker properly reports its state as -// asserted once its Assert() method is called. -func AssertedWakerAfterAssert(t *testing.T) { - var w Waker - w.Assert() - if !w.IsAsserted() { - t.Fatalf("Asserted waker is not reported as such") - } - - if !w.Clear() { - t.Fatalf("Asserted waker is not reported as such") - } -} - -// AssertedWakerAfterTwoAsserts tests that a waker properly reports its state as -// asserted once its Assert() method is called twice. -func AssertedWakerAfterTwoAsserts(t *testing.T) { - var w Waker - w.Assert() - w.Assert() - if !w.IsAsserted() { - t.Fatalf("Asserted waker is not reported as such") - } - - if !w.Clear() { - t.Fatalf("Asserted waker is not reported as such") - } -} - -// NotAssertedWakerWithSleeper tests that a waker properly reports its state as -// not asserted after a sleeper is associated with it. -func NotAssertedWakerWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - if w.IsAsserted() { - t.Fatalf("Non-asserted waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Non-asserted waker is reported as asserted") - } -} - -// NotAssertedWakerAfterWake tests that a waker properly reports its state as -// not asserted after a previous assert is consumed by a sleeper. That is, tests -// the "edge-triggered" behavior. -func NotAssertedWakerAfterWake(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Assert() - s.Fetch(true) - if w.IsAsserted() { - t.Fatalf("Consumed waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Consumed waker is reported as asserted") - } -} - -// AssertedWakerBeforeAdd tests that a waker causes a sleeper to not sleep if -// it's already asserted before being added. -func AssertedWakerBeforeAdd(t *testing.T) { - var w Waker - var s Sleeper - w.Assert() - s.AddWaker(&w, 0) - - if _, ok := s.Fetch(false); !ok { - t.Fatalf("Fetch failed even though asserted waker was added") - } -} - -// ClearedWaker tests that a waker properly reports its state as not asserted -// after it is cleared. -func ClearedWaker(t *testing.T) { - var w Waker - w.Assert() - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// ClearedWakerWithSleeper tests that a waker properly reports its state as -// not asserted when it is cleared while it has a sleeper associated with it. -func ClearedWakerWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// ClearedWakerAssertedWithSleeper tests that a waker properly reports its state -// as not asserted when it is cleared while it has a sleeper associated with it -// and has been asserted. -func ClearedWakerAssertedWithSleeper(t *testing.T) { - var w Waker - var s Sleeper - s.AddWaker(&w, 0) - w.Assert() - w.Clear() - if w.IsAsserted() { - t.Fatalf("Cleared waker is reported as asserted") - } - - if w.Clear() { - t.Fatalf("Cleared waker is reported as asserted") - } -} - -// TestBlock tests that a sleeper actually blocks waiting for the waker to -// assert its state. -func TestBlock(t *testing.T) { - var w Waker - var s Sleeper - - s.AddWaker(&w, 0) - - // Assert waker after one second. - before := time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - - // Fetch the result and make sure it took at least 500ms. - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - if d := time.Now().Sub(before); d < 500*time.Millisecond { - t.Fatalf("Duration was too short: %v", d) - } - - // Check that already-asserted waker completes inline. - w.Assert() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - - // Check that fetch sleeps if waker had been asserted but was reset - // before Fetch is called. - w.Assert() - w.Clear() - before = time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") - } - if d := time.Now().Sub(before); d < 500*time.Millisecond { - t.Fatalf("Duration was too short: %v", d) - } -} - -// TestNonBlock checks that a sleeper won't block if waker isn't asserted. -func TestNonBlock(t *testing.T) { - var w Waker - var s Sleeper - - // Don't block when there's no waker. - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when there is no waker") - } - - // Don't block when waker isn't asserted. - s.AddWaker(&w, 0) - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker was not asserted") - } - - // Don't block when waker was asserted, but isn't anymore. - w.Assert() - w.Clear() - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker was not asserted anymore") - } - - // Don't block when waker was consumed by previous Fetch(). - w.Assert() - if _, ok := s.Fetch(false); !ok { - t.Fatalf("Fetch failed even though waker was asserted") - } - - if _, ok := s.Fetch(false); ok { - t.Fatalf("Fetch succeeded when waker had been consumed") - } -} - -// TestMultiple checks that a sleeper can wait for and receives notifications -// from multiple wakers. -func TestMultiple(t *testing.T) { - s := Sleeper{} - w1 := Waker{} - w2 := Waker{} - - s.AddWaker(&w1, 0) - s.AddWaker(&w2, 1) - - w1.Assert() - w2.Assert() - - v, ok := s.Fetch(false) - if !ok { - t.Fatalf("Fetch failed when there are asserted wakers") - } - - if v != 0 && v != 1 { - t.Fatalf("Unexpected waker id: %v", v) - } - - want := 1 - v - v, ok = s.Fetch(false) - if !ok { - t.Fatalf("Fetch failed when there is an asserted waker") - } - - if v != want { - t.Fatalf("Unexpected waker id, got %v, want %v", v, want) - } -} - -// TestDoneFunction tests if calling Done() on a sleeper works properly. -func TestDoneFunction(t *testing.T) { - // Trivial case of no waker. - s := Sleeper{} - s.Done() - - // Cases when the sleeper has n wakers, but none are asserted. - for n := 1; n < 20; n++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - s.Done() - } - - // Cases when the sleeper has n wakers, and only the i-th one is - // asserted. - for n := 1; n < 20; n++ { - for i := 0; i < n; i++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - w[i].Assert() - s.Done() - } - } - - // Cases when the sleeper has n wakers, and the i-th one is asserted - // and cleared. - for n := 1; n < 20; n++ { - for i := 0; i < n; i++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - w[i].Assert() - w[i].Clear() - s.Done() - } - } - - // Cases when the sleeper has n wakers, with a random number of them - // asserted. - for n := 1; n < 20; n++ { - for iters := 0; iters < 1000; iters++ { - s := Sleeper{} - w := make([]Waker, n) - for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) - } - - // Pick the number of asserted elements, then assert - // random wakers. - asserted := rand.Int() % (n + 1) - for j := 0; j < asserted; j++ { - w[rand.Int()%n].Assert() - } - s.Done() - } - } -} - -// TestRace tests that multiple wakers can continuously send wake requests to -// the sleeper. -func TestRace(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - counts := make([]int, wakers) - w := make([]Waker, wakers) - s := Sleeper{} - - // Associate each waker and start goroutines that will assert them. - for i := range w { - s.AddWaker(&w[i], i) - go func(w *Waker) { - n := 0 - for n < wakeRequests { - if !w.IsAsserted() { - w.Assert() - n++ - } else { - runtime.Gosched() - } - } - }(&w[i]) - } - - // Wait for all wake up notifications from all wakers. - for i := 0; i < wakers*wakeRequests; i++ { - v, _ := s.Fetch(true) - counts[v]++ - } - - // Check that we got the right number for each. - for i, v := range counts { - if v != wakeRequests { - t.Errorf("Waker %v only got %v wakes", i, v) - } - } -} - -// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up -// from 4 wakers when at least one is already asserted. -func BenchmarkSleeperMultiSelect(b *testing.B) { - const count = 4 - s := Sleeper{} - w := make([]Waker, count) - for i := range w { - s.AddWaker(&w[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w[count-1].Assert() - s.Fetch(true) - } -} - -// BenchmarkGoMultiSelect measures how long it takes to fetch a zero-length -// struct from one of 4 channels when at least one is ready. -func BenchmarkGoMultiSelect(b *testing.B) { - const count = 4 - ch := make([]chan struct{}, count) - for i := range ch { - ch[i] = make(chan struct{}, 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ch[count-1] <- struct{}{} - select { - case <-ch[0]: - case <-ch[1]: - case <-ch[2]: - case <-ch[3]: - } - } -} - -// BenchmarkSleeperSingleSelect measures how long it takes to fetch a wake up -// from one waker that is already asserted. -func BenchmarkSleeperSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - w.Assert() - s.Fetch(true) - } -} - -// BenchmarkGoSingleSelect measures how long it takes to fetch a zero-length -// struct from a channel that already has it buffered. -func BenchmarkGoSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - ch <- struct{}{} - <-ch - } -} - -// BenchmarkSleeperAssertNonWaiting measures how long it takes to assert a -// channel that is already asserted. -func BenchmarkSleeperAssertNonWaiting(b *testing.B) { - w := Waker{} - w.Assert() - for i := 0; i < b.N; i++ { - w.Assert() - } - -} - -// BenchmarkGoAssertNonWaiting measures how long it takes to write to a channel -// that has already something written to it. -func BenchmarkGoAssertNonWaiting(b *testing.B) { - ch := make(chan struct{}, 1) - ch <- struct{}{} - for i := 0; i < b.N; i++ { - select { - case ch <- struct{}{}: - default: - } - } -} - -// BenchmarkSleeperWaitOnSingleSelect measures how long it takes to wait on one -// waker channel while another goroutine wakes up the sleeper. This assumes that -// a new goroutine doesn't run immediately (i.e., the creator of a new goroutine -// is allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - for i := 0; i < b.N; i++ { - go func() { - w.Assert() - }() - s.Fetch(true) - } - -} - -// BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one -// channel while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkGoWaitOnSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - for i := 0; i < b.N; i++ { - go func() { - ch <- struct{}{} - }() - <-ch - } -} - -// BenchmarkSleeperWaitOnMultiSelect measures how long it takes to wait on 4 -// wakers while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) { - const count = 4 - s := Sleeper{} - w := make([]Waker, count) - for i := range w { - s.AddWaker(&w[i], i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - w[count-1].Assert() - }() - s.Fetch(true) - } -} - -// BenchmarkGoWaitOnMultiSelect measures how long it takes to wait on 4 channels -// while another goroutine wakes up the sleeper. This assumes that a new -// goroutine doesn't run immediately (i.e., the creator of a new goroutine is -// allowed to go to sleep before the new goroutine has a chance to run). -func BenchmarkGoWaitOnMultiSelect(b *testing.B) { - const count = 4 - ch := make([]chan struct{}, count) - for i := range ch { - ch[i] = make(chan struct{}, 1) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { - ch[count-1] <- struct{}{} - }() - select { - case <-ch[0]: - case <-ch[1]: - case <-ch[2]: - case <-ch[3]: - } - } -} diff --git a/pkg/state/BUILD b/pkg/state/BUILD deleted file mode 100644 index c0f3c658d..000000000 --- a/pkg/state/BUILD +++ /dev/null @@ -1,78 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") - -go_template_instance( - name = "addr_range", - out = "addr_range.go", - package = "state", - prefix = "addr", - template = "//pkg/segment:generic_range", - types = { - "T": "uintptr", - }, -) - -go_template_instance( - name = "addr_set", - out = "addr_set.go", - consts = { - "minDegree": "10", - }, - imports = { - "reflect": "reflect", - }, - package = "state", - prefix = "addr", - template = "//pkg/segment:generic_set", - types = { - "Key": "uintptr", - "Range": "addrRange", - "Value": "reflect.Value", - "Functions": "addrSetFunctions", - }, -) - -go_library( - name = "state", - srcs = [ - "addr_range.go", - "addr_set.go", - "decode.go", - "encode.go", - "encode_unsafe.go", - "map.go", - "printer.go", - "state.go", - "stats.go", - ], - importpath = "gvisor.dev/gvisor/pkg/state", - visibility = ["//:sandbox"], - deps = [ - ":object_go_proto", - "@com_github_golang_protobuf//proto:go_default_library", - ], -) - -proto_library( - name = "object_proto", - srcs = ["object.proto"], - visibility = ["//:sandbox"], -) - -go_proto_library( - name = "object_go_proto", - importpath = "gvisor.dev/gvisor/pkg/state/object_go_proto", - proto = ":object_proto", - visibility = ["//:sandbox"], -) - -go_test( - name = "state_test", - timeout = "long", - srcs = ["state_test.go"], - embed = [":state"], -) diff --git a/pkg/state/addr_range.go b/pkg/state/addr_range.go new file mode 100755 index 000000000..45720c643 --- /dev/null +++ b/pkg/state/addr_range.go @@ -0,0 +1,62 @@ +package state + +// A Range represents a contiguous range of T. +// +// +stateify savable +type addrRange struct { + // Start is the inclusive start of the range. + Start uintptr + + // End is the exclusive end of the range. + End uintptr +} + +// WellFormed returns true if r.Start <= r.End. All other methods on a Range +// require that the Range is well-formed. +func (r addrRange) WellFormed() bool { + return r.Start <= r.End +} + +// Length returns the length of the range. +func (r addrRange) Length() uintptr { + return r.End - r.Start +} + +// Contains returns true if r contains x. +func (r addrRange) Contains(x uintptr) bool { + return r.Start <= x && x < r.End +} + +// Overlaps returns true if r and r2 overlap. +func (r addrRange) Overlaps(r2 addrRange) bool { + return r.Start < r2.End && r2.Start < r.End +} + +// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is +// contained within r. +func (r addrRange) IsSupersetOf(r2 addrRange) bool { + return r.Start <= r2.Start && r.End >= r2.End +} + +// Intersect returns a range consisting of the intersection between r and r2. +// If r and r2 do not overlap, Intersect returns a range with unspecified +// bounds, but for which Length() == 0. +func (r addrRange) Intersect(r2 addrRange) addrRange { + if r.Start < r2.Start { + r.Start = r2.Start + } + if r.End > r2.End { + r.End = r2.End + } + if r.End < r.Start { + r.End = r.Start + } + return r +} + +// CanSplitAt returns true if it is legal to split a segment spanning the range +// r at x; that is, splitting at x would produce two ranges, both of which have +// non-zero length. +func (r addrRange) CanSplitAt(x uintptr) bool { + return r.Contains(x) && r.Start < x +} diff --git a/pkg/state/addr_set.go b/pkg/state/addr_set.go new file mode 100755 index 000000000..bce7da87d --- /dev/null +++ b/pkg/state/addr_set.go @@ -0,0 +1,1274 @@ +package state + +import ( + __generics_imported0 "reflect" +) + +import ( + "bytes" + "fmt" +) + +const ( + // minDegree is the minimum degree of an internal node in a Set B-tree. + // + // - Any non-root node has at least minDegree-1 segments. + // + // - Any non-root internal (non-leaf) node has at least minDegree children. + // + // - The root node may have fewer than minDegree-1 segments, but it may + // only have 0 segments if the tree is empty. + // + // Our implementation requires minDegree >= 3. Higher values of minDegree + // usually improve performance, but increase memory usage for small sets. + addrminDegree = 10 + + addrmaxDegree = 2 * addrminDegree +) + +// A Set is a mapping of segments with non-overlapping Range keys. The zero +// value for a Set is an empty set. Set values are not safely movable nor +// copyable. Set is thread-compatible. +// +// +stateify savable +type addrSet struct { + root addrnode `state:".(*addrSegmentDataSlices)"` +} + +// IsEmpty returns true if the set contains no segments. +func (s *addrSet) IsEmpty() bool { + return s.root.nrSegments == 0 +} + +// IsEmptyRange returns true iff no segments in the set overlap the given +// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be +// more efficient. +func (s *addrSet) IsEmptyRange(r addrRange) bool { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return true + } + _, gap := s.Find(r.Start) + if !gap.Ok() { + return false + } + return r.End <= gap.End() +} + +// Span returns the total size of all segments in the set. +func (s *addrSet) Span() uintptr { + var sz uintptr + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sz += seg.Range().Length() + } + return sz +} + +// SpanRange returns the total size of the intersection of segments in the set +// with the given range. +func (s *addrSet) SpanRange(r addrRange) uintptr { + switch { + case r.Length() < 0: + panic(fmt.Sprintf("invalid range %v", r)) + case r.Length() == 0: + return 0 + } + var sz uintptr + for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() { + sz += seg.Range().Intersect(r).Length() + } + return sz +} + +// FirstSegment returns the first segment in the set. If the set is empty, +// FirstSegment returns a terminal iterator. +func (s *addrSet) FirstSegment() addrIterator { + if s.root.nrSegments == 0 { + return addrIterator{} + } + return s.root.firstSegment() +} + +// LastSegment returns the last segment in the set. If the set is empty, +// LastSegment returns a terminal iterator. +func (s *addrSet) LastSegment() addrIterator { + if s.root.nrSegments == 0 { + return addrIterator{} + } + return s.root.lastSegment() +} + +// FirstGap returns the first gap in the set. +func (s *addrSet) FirstGap() addrGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[0] + } + return addrGapIterator{n, 0} +} + +// LastGap returns the last gap in the set. +func (s *addrSet) LastGap() addrGapIterator { + n := &s.root + for n.hasChildren { + n = n.children[n.nrSegments] + } + return addrGapIterator{n, n.nrSegments} +} + +// Find returns the segment or gap whose range contains the given key. If a +// segment is found, the returned Iterator is non-terminal and the +// returned GapIterator is terminal. Otherwise, the returned Iterator is +// terminal and the returned GapIterator is non-terminal. +func (s *addrSet) Find(key uintptr) (addrIterator, addrGapIterator) { + n := &s.root + for { + + lower := 0 + upper := n.nrSegments + for lower < upper { + i := lower + (upper-lower)/2 + if r := n.keys[i]; key < r.End { + if key >= r.Start { + return addrIterator{n, i}, addrGapIterator{} + } + upper = i + } else { + lower = i + 1 + } + } + i := lower + if !n.hasChildren { + return addrIterator{}, addrGapIterator{n, i} + } + n = n.children[i] + } +} + +// FindSegment returns the segment whose range contains the given key. If no +// such segment exists, FindSegment returns a terminal iterator. +func (s *addrSet) FindSegment(key uintptr) addrIterator { + seg, _ := s.Find(key) + return seg +} + +// LowerBoundSegment returns the segment with the lowest range that contains a +// key greater than or equal to min. If no such segment exists, +// LowerBoundSegment returns a terminal iterator. +func (s *addrSet) LowerBoundSegment(min uintptr) addrIterator { + seg, gap := s.Find(min) + if seg.Ok() { + return seg + } + return gap.NextSegment() +} + +// UpperBoundSegment returns the segment with the highest range that contains a +// key less than or equal to max. If no such segment exists, UpperBoundSegment +// returns a terminal iterator. +func (s *addrSet) UpperBoundSegment(max uintptr) addrIterator { + seg, gap := s.Find(max) + if seg.Ok() { + return seg + } + return gap.PrevSegment() +} + +// FindGap returns the gap containing the given key. If no such gap exists +// (i.e. the set contains a segment containing that key), FindGap returns a +// terminal iterator. +func (s *addrSet) FindGap(key uintptr) addrGapIterator { + _, gap := s.Find(key) + return gap +} + +// LowerBoundGap returns the gap with the lowest range that is greater than or +// equal to min. +func (s *addrSet) LowerBoundGap(min uintptr) addrGapIterator { + seg, gap := s.Find(min) + if gap.Ok() { + return gap + } + return seg.NextGap() +} + +// UpperBoundGap returns the gap with the highest range that is less than or +// equal to max. +func (s *addrSet) UpperBoundGap(max uintptr) addrGapIterator { + seg, gap := s.Find(max) + if gap.Ok() { + return gap + } + return seg.PrevGap() +} + +// Add inserts the given segment into the set and returns true. If the new +// segment can be merged with adjacent segments, Add will do so. If the new +// segment would overlap an existing segment, Add returns false. If Add +// succeeds, all existing iterators are invalidated. +func (s *addrSet) Add(r addrRange, val __generics_imported0.Value) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.Insert(gap, r, val) + return true +} + +// AddWithoutMerging inserts the given segment into the set and returns true. +// If it would overlap an existing segment, AddWithoutMerging does nothing and +// returns false. If AddWithoutMerging succeeds, all existing iterators are +// invalidated. +func (s *addrSet) AddWithoutMerging(r addrRange, val __generics_imported0.Value) bool { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + gap := s.FindGap(r.Start) + if !gap.Ok() { + return false + } + if r.End > gap.End() { + return false + } + s.InsertWithoutMergingUnchecked(gap, r, val) + return true +} + +// Insert inserts the given segment into the given gap. If the new segment can +// be merged with adjacent segments, Insert will do so. Insert returns an +// iterator to the segment containing the inserted value (which may have been +// merged with other values). All existing iterators (including gap, but not +// including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, Insert panics. +// +// Insert is semantically equivalent to a InsertWithoutMerging followed by a +// Merge, but may be more efficient. Note that there is no unchecked variant of +// Insert since Insert must retrieve and inspect gap's predecessor and +// successor segments regardless. +func (s *addrSet) Insert(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + prev, next := gap.PrevSegment(), gap.NextSegment() + if prev.Ok() && prev.End() > r.Start { + panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range())) + } + if next.Ok() && next.Start() < r.End { + panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range())) + } + if prev.Ok() && prev.End() == r.Start { + if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + prev.SetEndUnchecked(r.End) + prev.SetValue(mval) + if next.Ok() && next.Start() == r.End { + val = mval + if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { + prev.SetEndUnchecked(next.End()) + prev.SetValue(mval) + return s.Remove(next).PrevSegment() + } + } + return prev + } + } + if next.Ok() && next.Start() == r.End { + if mval, ok := (addrSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok { + next.SetStartUnchecked(r.Start) + next.SetValue(mval) + return next + } + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMerging inserts the given segment into the given gap and +// returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// If the gap cannot accommodate the segment, or if r is invalid, +// InsertWithoutMerging panics. +func (s *addrSet) InsertWithoutMerging(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if gr := gap.Range(); !gr.IsSupersetOf(r) { + panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr)) + } + return s.InsertWithoutMergingUnchecked(gap, r, val) +} + +// InsertWithoutMergingUnchecked inserts the given segment into the given gap +// and returns an iterator to the inserted segment. All existing iterators +// (including gap, but not including the returned iterator) are invalidated. +// +// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +func (s *addrSet) InsertWithoutMergingUnchecked(gap addrGapIterator, r addrRange, val __generics_imported0.Value) addrIterator { + gap = gap.node.rebalanceBeforeInsert(gap) + copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) + copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) + gap.node.keys[gap.index] = r + gap.node.values[gap.index] = val + gap.node.nrSegments++ + return addrIterator{gap.node, gap.index} +} + +// Remove removes the given segment and returns an iterator to the vacated gap. +// All existing iterators (including seg, but not including the returned +// iterator) are invalidated. +func (s *addrSet) Remove(seg addrIterator) addrGapIterator { + + if seg.node.hasChildren { + + victim := seg.PrevSegment() + + seg.SetRangeUnchecked(victim.Range()) + seg.SetValue(victim.Value()) + return s.Remove(victim).NextGap() + } + copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) + copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) + addrSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) + seg.node.nrSegments-- + return seg.node.rebalanceAfterRemove(addrGapIterator{seg.node, seg.index}) +} + +// RemoveAll removes all segments from the set. All existing iterators are +// invalidated. +func (s *addrSet) RemoveAll() { + s.root = addrnode{} +} + +// RemoveRange removes all segments in the given range. An iterator to the +// newly formed gap is returned, and all existing iterators are invalidated. +func (s *addrSet) RemoveRange(r addrRange) addrGapIterator { + seg, gap := s.Find(r.Start) + if seg.Ok() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() { + seg = s.Isolate(seg, r) + gap = s.Remove(seg) + } + return gap +} + +// Merge attempts to merge two neighboring segments. If successful, Merge +// returns an iterator to the merged segment, and all existing iterators are +// invalidated. Otherwise, Merge returns a terminal iterator. +// +// If first is not the predecessor of second, Merge panics. +func (s *addrSet) Merge(first, second addrIterator) addrIterator { + if first.NextSegment() != second { + panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range())) + } + return s.MergeUnchecked(first, second) +} + +// MergeUnchecked attempts to merge two neighboring segments. If successful, +// MergeUnchecked returns an iterator to the merged segment, and all existing +// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal +// iterator. +// +// Precondition: first is the predecessor of second: first.NextSegment() == +// second, first == second.PrevSegment(). +func (s *addrSet) MergeUnchecked(first, second addrIterator) addrIterator { + if first.End() == second.Start() { + if mval, ok := (addrSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok { + + first.SetEndUnchecked(second.End()) + first.SetValue(mval) + return s.Remove(second).PrevSegment() + } + } + return addrIterator{} +} + +// MergeAll attempts to merge all adjacent segments in the set. All existing +// iterators are invalidated. +func (s *addrSet) MergeAll() { + seg := s.FirstSegment() + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeRange attempts to merge all adjacent segments that contain a key in the +// specific range. All existing iterators are invalidated. +func (s *addrSet) MergeRange(r addrRange) { + seg := s.LowerBoundSegment(r.Start) + if !seg.Ok() { + return + } + next := seg.NextSegment() + for next.Ok() && next.Range().Start < r.End { + if mseg := s.MergeUnchecked(seg, next); mseg.Ok() { + seg, next = mseg, mseg.NextSegment() + } else { + seg, next = next, next.NextSegment() + } + } +} + +// MergeAdjacent attempts to merge the segment containing r.Start with its +// predecessor, and the segment containing r.End-1 with its successor. +func (s *addrSet) MergeAdjacent(r addrRange) { + first := s.FindSegment(r.Start) + if first.Ok() { + if prev := first.PrevSegment(); prev.Ok() { + s.Merge(prev, first) + } + } + last := s.FindSegment(r.End - 1) + if last.Ok() { + if next := last.NextSegment(); next.Ok() { + s.Merge(last, next) + } + } +} + +// Split splits the given segment at the given key and returns iterators to the +// two resulting segments. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +// +// If the segment cannot be split at split (because split is at the start or +// end of the segment's range, so splitting would produce a segment with zero +// length, or because split falls outside the segment's range altogether), +// Split panics. +func (s *addrSet) Split(seg addrIterator, split uintptr) (addrIterator, addrIterator) { + if !seg.Range().CanSplitAt(split) { + panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split)) + } + return s.SplitUnchecked(seg, split) +} + +// SplitUnchecked splits the given segment at the given key and returns +// iterators to the two resulting segments. All existing iterators (including +// seg, but not including the returned iterators) are invalidated. +// +// Preconditions: seg.Start() < key < seg.End(). +func (s *addrSet) SplitUnchecked(seg addrIterator, split uintptr) (addrIterator, addrIterator) { + val1, val2 := (addrSetFunctions{}).Split(seg.Range(), seg.Value(), split) + end2 := seg.End() + seg.SetEndUnchecked(split) + seg.SetValue(val1) + seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), addrRange{split, end2}, val2) + + return seg2.PrevSegment(), seg2 +} + +// SplitAt splits the segment straddling split, if one exists. SplitAt returns +// true if a segment was split and false otherwise. If SplitAt splits a +// segment, all existing iterators are invalidated. +func (s *addrSet) SplitAt(split uintptr) bool { + if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) { + s.SplitUnchecked(seg, split) + return true + } + return false +} + +// Isolate ensures that the given segment's range does not escape r by +// splitting at r.Start and r.End if necessary, and returns an updated iterator +// to the bounded segment. All existing iterators (including seg, but not +// including the returned iterators) are invalidated. +func (s *addrSet) Isolate(seg addrIterator, r addrRange) addrIterator { + if seg.Range().CanSplitAt(r.Start) { + _, seg = s.SplitUnchecked(seg, r.Start) + } + if seg.Range().CanSplitAt(r.End) { + seg, _ = s.SplitUnchecked(seg, r.End) + } + return seg +} + +// ApplyContiguous applies a function to a contiguous range of segments, +// splitting if necessary. The function is applied until the first gap is +// encountered, at which point the gap is returned. If the function is applied +// across the entire range, a terminal gap is returned. All existing iterators +// are invalidated. +// +// N.B. The Iterator must not be invalidated by the function. +func (s *addrSet) ApplyContiguous(r addrRange, fn func(seg addrIterator)) addrGapIterator { + seg, gap := s.Find(r.Start) + if !seg.Ok() { + return gap + } + for { + seg = s.Isolate(seg, r) + fn(seg) + if seg.End() >= r.End { + return addrGapIterator{} + } + gap = seg.NextGap() + if !gap.IsEmpty() { + return gap + } + seg = gap.NextSegment() + if !seg.Ok() { + + return addrGapIterator{} + } + } +} + +// +stateify savable +type addrnode struct { + // An internal binary tree node looks like: + // + // K + // / \ + // Cl Cr + // + // where all keys in the subtree rooted by Cl (the left subtree) are less + // than K (the key of the parent node), and all keys in the subtree rooted + // by Cr (the right subtree) are greater than K. + // + // An internal B-tree node's indexes work out to look like: + // + // K0 K1 K2 ... Kn-1 + // / \/ \/ \ ... / \ + // C0 C1 C2 C3 ... Cn-1 Cn + // + // where n is nrSegments. + nrSegments int + + // parent is a pointer to this node's parent. If this node is root, parent + // is nil. + parent *addrnode + + // parentIndex is the index of this node in parent.children. + parentIndex int + + // Flag for internal nodes that is technically redundant with "children[0] + // != nil", but is stored in the first cache line. "hasChildren" rather + // than "isLeaf" because false must be the correct value for an empty root. + hasChildren bool + + // Nodes store keys and values in separate arrays to maximize locality in + // the common case (scanning keys for lookup). + keys [addrmaxDegree - 1]addrRange + values [addrmaxDegree - 1]__generics_imported0.Value + children [addrmaxDegree]*addrnode +} + +// firstSegment returns the first segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *addrnode) firstSegment() addrIterator { + for n.hasChildren { + n = n.children[0] + } + return addrIterator{n, 0} +} + +// lastSegment returns the last segment in the subtree rooted by n. +// +// Preconditions: n.nrSegments != 0. +func (n *addrnode) lastSegment() addrIterator { + for n.hasChildren { + n = n.children[n.nrSegments] + } + return addrIterator{n, n.nrSegments - 1} +} + +func (n *addrnode) prevSibling() *addrnode { + if n.parent == nil || n.parentIndex == 0 { + return nil + } + return n.parent.children[n.parentIndex-1] +} + +func (n *addrnode) nextSibling() *addrnode { + if n.parent == nil || n.parentIndex == n.parent.nrSegments { + return nil + } + return n.parent.children[n.parentIndex+1] +} + +// rebalanceBeforeInsert splits n and its ancestors if they are full, as +// required for insertion, and returns an updated iterator to the position +// represented by gap. +func (n *addrnode) rebalanceBeforeInsert(gap addrGapIterator) addrGapIterator { + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } + if n.nrSegments < addrmaxDegree-1 { + return gap + } + if n.parent == nil { + + left := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n, + parentIndex: 0, + hasChildren: n.hasChildren, + } + right := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n, + parentIndex: 1, + hasChildren: n.hasChildren, + } + copy(left.keys[:addrminDegree-1], n.keys[:addrminDegree-1]) + copy(left.values[:addrminDegree-1], n.values[:addrminDegree-1]) + copy(right.keys[:addrminDegree-1], n.keys[addrminDegree:]) + copy(right.values[:addrminDegree-1], n.values[addrminDegree:]) + n.keys[0], n.values[0] = n.keys[addrminDegree-1], n.values[addrminDegree-1] + addrzeroValueSlice(n.values[1:]) + if n.hasChildren { + copy(left.children[:addrminDegree], n.children[:addrminDegree]) + copy(right.children[:addrminDegree], n.children[addrminDegree:]) + addrzeroNodeSlice(n.children[2:]) + for i := 0; i < addrminDegree; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + right.children[i].parent = right + right.children[i].parentIndex = i + } + } + n.nrSegments = 1 + n.hasChildren = true + n.children[0] = left + n.children[1] = right + if gap.node != n { + return gap + } + if gap.index < addrminDegree { + return addrGapIterator{left, gap.index} + } + return addrGapIterator{right, gap.index - addrminDegree} + } + + copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments]) + copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments]) + n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[addrminDegree-1], n.values[addrminDegree-1] + copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1]) + for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ { + n.parent.children[i].parentIndex = i + } + sibling := &addrnode{ + nrSegments: addrminDegree - 1, + parent: n.parent, + parentIndex: n.parentIndex + 1, + hasChildren: n.hasChildren, + } + n.parent.children[n.parentIndex+1] = sibling + n.parent.nrSegments++ + copy(sibling.keys[:addrminDegree-1], n.keys[addrminDegree:]) + copy(sibling.values[:addrminDegree-1], n.values[addrminDegree:]) + addrzeroValueSlice(n.values[addrminDegree-1:]) + if n.hasChildren { + copy(sibling.children[:addrminDegree], n.children[addrminDegree:]) + addrzeroNodeSlice(n.children[addrminDegree:]) + for i := 0; i < addrminDegree; i++ { + sibling.children[i].parent = sibling + sibling.children[i].parentIndex = i + } + } + n.nrSegments = addrminDegree - 1 + + if gap.node != n { + return gap + } + if gap.index < addrminDegree { + return gap + } + return addrGapIterator{sibling, gap.index - addrminDegree} +} + +// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient +// (contain fewer segments than required by B-tree invariants), as required for +// removal, and returns an updated iterator to the position represented by gap. +// +// Precondition: n is the only node in the tree that may currently violate a +// B-tree invariant. +func (n *addrnode) rebalanceAfterRemove(gap addrGapIterator) addrGapIterator { + for { + if n.nrSegments >= addrminDegree-1 { + return gap + } + if n.parent == nil { + + return gap + } + + if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= addrminDegree { + copy(n.keys[1:], n.keys[:n.nrSegments]) + copy(n.values[1:], n.values[:n.nrSegments]) + n.keys[0] = n.parent.keys[n.parentIndex-1] + n.values[0] = n.parent.values[n.parentIndex-1] + n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1] + n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1] + addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + copy(n.children[1:], n.children[:n.nrSegments+1]) + n.children[0] = sibling.children[sibling.nrSegments] + sibling.children[sibling.nrSegments] = nil + n.children[0].parent = n + n.children[0].parentIndex = 0 + for i := 1; i < n.nrSegments+2; i++ { + n.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling && gap.index == sibling.nrSegments { + return addrGapIterator{n, 0} + } + if gap.node == n { + return addrGapIterator{n, gap.index + 1} + } + return gap + } + if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= addrminDegree { + n.keys[n.nrSegments] = n.parent.keys[n.parentIndex] + n.values[n.nrSegments] = n.parent.values[n.parentIndex] + n.parent.keys[n.parentIndex] = sibling.keys[0] + n.parent.values[n.parentIndex] = sibling.values[0] + copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:]) + copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:]) + addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1]) + if n.hasChildren { + n.children[n.nrSegments+1] = sibling.children[0] + copy(sibling.children[:sibling.nrSegments], sibling.children[1:]) + sibling.children[sibling.nrSegments] = nil + n.children[n.nrSegments+1].parent = n + n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1 + for i := 0; i < sibling.nrSegments; i++ { + sibling.children[i].parentIndex = i + } + } + n.nrSegments++ + sibling.nrSegments-- + if gap.node == sibling { + if gap.index == 0 { + return addrGapIterator{n, n.nrSegments} + } + return addrGapIterator{sibling, gap.index - 1} + } + return gap + } + + p := n.parent + if p.nrSegments == 1 { + + left, right := p.children[0], p.children[1] + p.nrSegments = left.nrSegments + right.nrSegments + 1 + p.hasChildren = left.hasChildren + p.keys[left.nrSegments] = p.keys[0] + p.values[left.nrSegments] = p.values[0] + copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments]) + copy(p.values[:left.nrSegments], left.values[:left.nrSegments]) + copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1]) + copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := 0; i < p.nrSegments+1; i++ { + p.children[i].parent = p + p.children[i].parentIndex = i + } + } else { + p.children[0] = nil + p.children[1] = nil + } + if gap.node == left { + return addrGapIterator{p, gap.index} + } + if gap.node == right { + return addrGapIterator{p, gap.index + left.nrSegments + 1} + } + return gap + } + // Merge n and either sibling, along with the segment separating the + // two, into whichever of the two nodes comes first. This is the + // reverse of the non-root splitting case in + // node.rebalanceBeforeInsert. + var left, right *addrnode + if n.parentIndex > 0 { + left = n.prevSibling() + right = n + } else { + left = n + right = n.nextSibling() + } + + if gap.node == right { + gap = addrGapIterator{left, gap.index + left.nrSegments + 1} + } + left.keys[left.nrSegments] = p.keys[left.parentIndex] + left.values[left.nrSegments] = p.values[left.parentIndex] + copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments]) + copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments]) + if left.hasChildren { + copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1]) + for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ { + left.children[i].parent = left + left.children[i].parentIndex = i + } + } + left.nrSegments += right.nrSegments + 1 + copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments]) + copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments]) + addrSetFunctions{}.ClearValue(&p.values[p.nrSegments-1]) + copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1]) + for i := 0; i < p.nrSegments; i++ { + p.children[i].parentIndex = i + } + p.children[p.nrSegments] = nil + p.nrSegments-- + + n = p + } +} + +// A Iterator is conceptually one of: +// +// - A pointer to a segment in a set; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Iterators are copyable values and are meaningfully equality-comparable. The +// zero value of Iterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type addrIterator struct { + // node is the node containing the iterated segment. If the iterator is + // terminal, node is nil. + node *addrnode + + // index is the index of the segment in node.keys/values. + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (seg addrIterator) Ok() bool { + return seg.node != nil +} + +// Range returns the iterated segment's range key. +func (seg addrIterator) Range() addrRange { + return seg.node.keys[seg.index] +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (seg addrIterator) Start() uintptr { + return seg.node.keys[seg.index].Start +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (seg addrIterator) End() uintptr { + return seg.node.keys[seg.index].End +} + +// SetRangeUnchecked mutates the iterated segment's range key. This operation +// does not invalidate any iterators. +// +// Preconditions: +// +// - r.Length() > 0. +// +// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), +// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then +// r.start >= seg.PrevSegment().End(). +func (seg addrIterator) SetRangeUnchecked(r addrRange) { + seg.node.keys[seg.index] = r +} + +// SetRange mutates the iterated segment's range key. If the new range would +// cause the iterated segment to overlap another segment, or if the new range +// is invalid, SetRange panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetRange(r addrRange) { + if r.Length() <= 0 { + panic(fmt.Sprintf("invalid segment range %v", r)) + } + if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range())) + } + if next := seg.NextSegment(); next.Ok() && r.End > next.Start() { + panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range())) + } + seg.SetRangeUnchecked(r) +} + +// SetStartUnchecked mutates the iterated segment's start. This operation does +// not invalidate any iterators. +// +// Preconditions: The new start must be valid: start < seg.End(); if +// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +func (seg addrIterator) SetStartUnchecked(start uintptr) { + seg.node.keys[seg.index].Start = start +} + +// SetStart mutates the iterated segment's start. If the new start value would +// cause the iterated segment to overlap another segment, or would result in an +// invalid range, SetStart panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetStart(start uintptr) { + if start >= seg.End() { + panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range())) + } + if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() { + panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range())) + } + seg.SetStartUnchecked(start) +} + +// SetEndUnchecked mutates the iterated segment's end. This operation does not +// invalidate any iterators. +// +// Preconditions: The new end must be valid: end > seg.Start(); if +// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +func (seg addrIterator) SetEndUnchecked(end uintptr) { + seg.node.keys[seg.index].End = end +} + +// SetEnd mutates the iterated segment's end. If the new end value would cause +// the iterated segment to overlap another segment, or would result in an +// invalid range, SetEnd panics. This operation does not invalidate any +// iterators. +func (seg addrIterator) SetEnd(end uintptr) { + if end <= seg.Start() { + panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range())) + } + if next := seg.NextSegment(); next.Ok() && end > next.Start() { + panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range())) + } + seg.SetEndUnchecked(end) +} + +// Value returns a copy of the iterated segment's value. +func (seg addrIterator) Value() __generics_imported0.Value { + return seg.node.values[seg.index] +} + +// ValuePtr returns a pointer to the iterated segment's value. The pointer is +// invalidated if the iterator is invalidated. This operation does not +// invalidate any iterators. +func (seg addrIterator) ValuePtr() *__generics_imported0.Value { + return &seg.node.values[seg.index] +} + +// SetValue mutates the iterated segment's value. This operation does not +// invalidate any iterators. +func (seg addrIterator) SetValue(val __generics_imported0.Value) { + seg.node.values[seg.index] = val +} + +// PrevSegment returns the iterated segment's predecessor. If there is no +// preceding segment, PrevSegment returns a terminal iterator. +func (seg addrIterator) PrevSegment() addrIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index].lastSegment() + } + if seg.index > 0 { + return addrIterator{seg.node, seg.index - 1} + } + if seg.node.parent == nil { + return addrIterator{} + } + return addrsegmentBeforePosition(seg.node.parent, seg.node.parentIndex) +} + +// NextSegment returns the iterated segment's successor. If there is no +// succeeding segment, NextSegment returns a terminal iterator. +func (seg addrIterator) NextSegment() addrIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment() + } + if seg.index < seg.node.nrSegments-1 { + return addrIterator{seg.node, seg.index + 1} + } + if seg.node.parent == nil { + return addrIterator{} + } + return addrsegmentAfterPosition(seg.node.parent, seg.node.parentIndex) +} + +// PrevGap returns the gap immediately before the iterated segment. +func (seg addrIterator) PrevGap() addrGapIterator { + if seg.node.hasChildren { + + return seg.node.children[seg.index].lastSegment().NextGap() + } + return addrGapIterator{seg.node, seg.index} +} + +// NextGap returns the gap immediately after the iterated segment. +func (seg addrIterator) NextGap() addrGapIterator { + if seg.node.hasChildren { + return seg.node.children[seg.index+1].firstSegment().PrevGap() + } + return addrGapIterator{seg.node, seg.index + 1} +} + +// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent, +// or the gap before the iterated segment otherwise. If seg.Start() == +// Functions.MinKey(), PrevNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be +// non-terminal. +func (seg addrIterator) PrevNonEmpty() (addrIterator, addrGapIterator) { + gap := seg.PrevGap() + if gap.Range().Length() != 0 { + return addrIterator{}, gap + } + return gap.PrevSegment(), addrGapIterator{} +} + +// NextNonEmpty returns the iterated segment's successor if it is adjacent, or +// the gap after the iterated segment otherwise. If seg.End() == +// Functions.MaxKey(), NextNonEmpty will return two terminal iterators. +// Otherwise, exactly one of the iterators returned by NextNonEmpty will be +// non-terminal. +func (seg addrIterator) NextNonEmpty() (addrIterator, addrGapIterator) { + gap := seg.NextGap() + if gap.Range().Length() != 0 { + return addrIterator{}, gap + } + return gap.NextSegment(), addrGapIterator{} +} + +// A GapIterator is conceptually one of: +// +// - A pointer to a position between two segments, before the first segment, or +// after the last segment in a set, called a *gap*; or +// +// - A terminal iterator, which is a sentinel indicating that the end of +// iteration has been reached. +// +// Note that the gap between two adjacent segments exists (iterators to it are +// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true +// for such gaps. An empty set contains a single gap, spanning the entire range +// of the set's keys. +// +// GapIterators are copyable values and are meaningfully equality-comparable. +// The zero value of GapIterator is a terminal iterator. +// +// Unless otherwise specified, any mutation of a set invalidates all existing +// iterators into the set. +type addrGapIterator struct { + // The representation of a GapIterator is identical to that of an Iterator, + // except that index corresponds to positions between segments in the same + // way as for node.children (see comment for node.nrSegments). + node *addrnode + index int +} + +// Ok returns true if the iterator is not terminal. All other methods are only +// valid for non-terminal iterators. +func (gap addrGapIterator) Ok() bool { + return gap.node != nil +} + +// Range returns the range spanned by the iterated gap. +func (gap addrGapIterator) Range() addrRange { + return addrRange{gap.Start(), gap.End()} +} + +// Start is equivalent to Range().Start, but should be preferred if only the +// start of the range is needed. +func (gap addrGapIterator) Start() uintptr { + if ps := gap.PrevSegment(); ps.Ok() { + return ps.End() + } + return addrSetFunctions{}.MinKey() +} + +// End is equivalent to Range().End, but should be preferred if only the end of +// the range is needed. +func (gap addrGapIterator) End() uintptr { + if ns := gap.NextSegment(); ns.Ok() { + return ns.Start() + } + return addrSetFunctions{}.MaxKey() +} + +// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is +// between two adjacent segments.) +func (gap addrGapIterator) IsEmpty() bool { + return gap.Range().Length() == 0 +} + +// PrevSegment returns the segment immediately before the iterated gap. If no +// such segment exists, PrevSegment returns a terminal iterator. +func (gap addrGapIterator) PrevSegment() addrIterator { + return addrsegmentBeforePosition(gap.node, gap.index) +} + +// NextSegment returns the segment immediately after the iterated gap. If no +// such segment exists, NextSegment returns a terminal iterator. +func (gap addrGapIterator) NextSegment() addrIterator { + return addrsegmentAfterPosition(gap.node, gap.index) +} + +// PrevGap returns the iterated gap's predecessor. If no such gap exists, +// PrevGap returns a terminal iterator. +func (gap addrGapIterator) PrevGap() addrGapIterator { + seg := gap.PrevSegment() + if !seg.Ok() { + return addrGapIterator{} + } + return seg.PrevGap() +} + +// NextGap returns the iterated gap's successor. If no such gap exists, NextGap +// returns a terminal iterator. +func (gap addrGapIterator) NextGap() addrGapIterator { + seg := gap.NextSegment() + if !seg.Ok() { + return addrGapIterator{} + } + return seg.NextGap() +} + +// segmentBeforePosition returns the predecessor segment of the position given +// by n.children[i], which may or may not contain a child. If no such segment +// exists, segmentBeforePosition returns a terminal iterator. +func addrsegmentBeforePosition(n *addrnode, i int) addrIterator { + for i == 0 { + if n.parent == nil { + return addrIterator{} + } + n, i = n.parent, n.parentIndex + } + return addrIterator{n, i - 1} +} + +// segmentAfterPosition returns the successor segment of the position given by +// n.children[i], which may or may not contain a child. If no such segment +// exists, segmentAfterPosition returns a terminal iterator. +func addrsegmentAfterPosition(n *addrnode, i int) addrIterator { + for i == n.nrSegments { + if n.parent == nil { + return addrIterator{} + } + n, i = n.parent, n.parentIndex + } + return addrIterator{n, i} +} + +func addrzeroValueSlice(slice []__generics_imported0.Value) { + + for i := range slice { + addrSetFunctions{}.ClearValue(&slice[i]) + } +} + +func addrzeroNodeSlice(slice []*addrnode) { + for i := range slice { + slice[i] = nil + } +} + +// String stringifies a Set for debugging. +func (s *addrSet) String() string { + return s.root.String() +} + +// String stringifes a node (and all of its children) for debugging. +func (n *addrnode) String() string { + var buf bytes.Buffer + n.writeDebugString(&buf, "") + return buf.String() +} + +func (n *addrnode) writeDebugString(buf *bytes.Buffer, prefix string) { + if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) { + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren)) + } + for i := 0; i < n.nrSegments; i++ { + if child := n.children[i]; child != nil { + cprefix := fmt.Sprintf("%s- % 3d ", prefix, i) + if child.parent != n || child.parentIndex != i { + buf.WriteString(cprefix) + buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i)) + } + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) + } + buf.WriteString(prefix) + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + if child := n.children[n.nrSegments]; child != nil { + child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) + } +} + +// SegmentDataSlices represents segments from a set as slices of start, end, and +// values. SegmentDataSlices is primarily used as an intermediate representation +// for save/restore and the layout here is optimized for that. +// +// +stateify savable +type addrSegmentDataSlices struct { + Start []uintptr + End []uintptr + Values []__generics_imported0.Value +} + +// ExportSortedSlice returns a copy of all segments in the given set, in ascending +// key order. +func (s *addrSet) ExportSortedSlices() *addrSegmentDataSlices { + var sds addrSegmentDataSlices + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + sds.Start = append(sds.Start, seg.Start()) + sds.End = append(sds.End, seg.End()) + sds.Values = append(sds.Values, seg.Value()) + } + sds.Start = sds.Start[:len(sds.Start):len(sds.Start)] + sds.End = sds.End[:len(sds.End):len(sds.End)] + sds.Values = sds.Values[:len(sds.Values):len(sds.Values)] + return &sds +} + +// ImportSortedSlice initializes the given set from the given slice. +// +// Preconditions: s must be empty. sds must represent a valid set (the segments +// in sds must have valid lengths that do not overlap). The segments in sds +// must be sorted in ascending key order. +func (s *addrSet) ImportSortedSlices(sds *addrSegmentDataSlices) error { + if !s.IsEmpty() { + return fmt.Errorf("cannot import into non-empty set %v", s) + } + gap := s.FirstGap() + for i := range sds.Start { + r := addrRange{sds.Start[i], sds.End[i]} + if !gap.Range().IsSupersetOf(r) { + return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i]) + } + gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap() + } + return nil +} +func (s *addrSet) saveRoot() *addrSegmentDataSlices { + return s.ExportSortedSlices() +} + +func (s *addrSet) loadRoot(sds *addrSegmentDataSlices) { + if err := s.ImportSortedSlices(sds); err != nil { + panic(err) + } +} diff --git a/pkg/state/object.proto b/pkg/state/object.proto deleted file mode 100644 index 952289069..000000000 --- a/pkg/state/object.proto +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor.state.statefile; - -// Slice is a slice value. -message Slice { - uint32 length = 1; - uint32 capacity = 2; - uint64 ref_value = 3; -} - -// Array is an array value. -message Array { - repeated Object contents = 1; -} - -// Map is a map value. -message Map { - repeated Object keys = 1; - repeated Object values = 2; -} - -// Interface is an interface value. -message Interface { - string type = 1; - Object value = 2; -} - -// Struct is a basic composite value. -message Struct { - repeated Field fields = 1; -} - -// Field encodes a single field. -message Field { - string name = 1; - Object value = 2; -} - -// Uint16s encodes an uint16 array. To be used inside oneof structure. -message Uint16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated uint32 values = 1; -} - -// Uint32s encodes an uint32 array. To be used inside oneof structure. -message Uint32s { - repeated fixed32 values = 1; -} - -// Uint64s encodes an uint64 array. To be used inside oneof structure. -message Uint64s { - repeated fixed64 values = 1; -} - -// Uintptrs encodes an uintptr array. To be used inside oneof structure. -message Uintptrs { - repeated fixed64 values = 1; -} - -// Int8s encodes an int8 array. To be used inside oneof structure. -message Int8s { - bytes values = 1; -} - -// Int16s encodes an int16 array. To be used inside oneof structure. -message Int16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated int32 values = 1; -} - -// Int32s encodes an int32 array. To be used inside oneof structure. -message Int32s { - repeated sfixed32 values = 1; -} - -// Int64s encodes an int64 array. To be used inside oneof structure. -message Int64s { - repeated sfixed64 values = 1; -} - -// Bools encodes a boolean array. To be used inside oneof structure. -message Bools { - repeated bool values = 1; -} - -// Float64s encodes a float64 array. To be used inside oneof structure. -message Float64s { - repeated double values = 1; -} - -// Float32s encodes a float32 array. To be used inside oneof structure. -message Float32s { - repeated float values = 1; -} - -// Object are primitive encodings. -// -// Note that ref_value references an Object.id, below. -message Object { - oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; - } -} diff --git a/pkg/state/object_go_proto/object.pb.go b/pkg/state/object_go_proto/object.pb.go new file mode 100755 index 000000000..dc5127149 --- /dev/null +++ b/pkg/state/object_go_proto/object.pb.go @@ -0,0 +1,1195 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: pkg/state/object.proto + +package gvisor_state_statefile + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type Slice struct { + Length uint32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"` + Capacity uint32 `protobuf:"varint,2,opt,name=capacity,proto3" json:"capacity,omitempty"` + RefValue uint64 `protobuf:"varint,3,opt,name=ref_value,json=refValue,proto3" json:"ref_value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Slice) Reset() { *m = Slice{} } +func (m *Slice) String() string { return proto.CompactTextString(m) } +func (*Slice) ProtoMessage() {} +func (*Slice) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{0} +} + +func (m *Slice) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Slice.Unmarshal(m, b) +} +func (m *Slice) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Slice.Marshal(b, m, deterministic) +} +func (m *Slice) XXX_Merge(src proto.Message) { + xxx_messageInfo_Slice.Merge(m, src) +} +func (m *Slice) XXX_Size() int { + return xxx_messageInfo_Slice.Size(m) +} +func (m *Slice) XXX_DiscardUnknown() { + xxx_messageInfo_Slice.DiscardUnknown(m) +} + +var xxx_messageInfo_Slice proto.InternalMessageInfo + +func (m *Slice) GetLength() uint32 { + if m != nil { + return m.Length + } + return 0 +} + +func (m *Slice) GetCapacity() uint32 { + if m != nil { + return m.Capacity + } + return 0 +} + +func (m *Slice) GetRefValue() uint64 { + if m != nil { + return m.RefValue + } + return 0 +} + +type Array struct { + Contents []*Object `protobuf:"bytes,1,rep,name=contents,proto3" json:"contents,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Array) Reset() { *m = Array{} } +func (m *Array) String() string { return proto.CompactTextString(m) } +func (*Array) ProtoMessage() {} +func (*Array) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{1} +} + +func (m *Array) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Array.Unmarshal(m, b) +} +func (m *Array) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Array.Marshal(b, m, deterministic) +} +func (m *Array) XXX_Merge(src proto.Message) { + xxx_messageInfo_Array.Merge(m, src) +} +func (m *Array) XXX_Size() int { + return xxx_messageInfo_Array.Size(m) +} +func (m *Array) XXX_DiscardUnknown() { + xxx_messageInfo_Array.DiscardUnknown(m) +} + +var xxx_messageInfo_Array proto.InternalMessageInfo + +func (m *Array) GetContents() []*Object { + if m != nil { + return m.Contents + } + return nil +} + +type Map struct { + Keys []*Object `protobuf:"bytes,1,rep,name=keys,proto3" json:"keys,omitempty"` + Values []*Object `protobuf:"bytes,2,rep,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Map) Reset() { *m = Map{} } +func (m *Map) String() string { return proto.CompactTextString(m) } +func (*Map) ProtoMessage() {} +func (*Map) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{2} +} + +func (m *Map) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Map.Unmarshal(m, b) +} +func (m *Map) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Map.Marshal(b, m, deterministic) +} +func (m *Map) XXX_Merge(src proto.Message) { + xxx_messageInfo_Map.Merge(m, src) +} +func (m *Map) XXX_Size() int { + return xxx_messageInfo_Map.Size(m) +} +func (m *Map) XXX_DiscardUnknown() { + xxx_messageInfo_Map.DiscardUnknown(m) +} + +var xxx_messageInfo_Map proto.InternalMessageInfo + +func (m *Map) GetKeys() []*Object { + if m != nil { + return m.Keys + } + return nil +} + +func (m *Map) GetValues() []*Object { + if m != nil { + return m.Values + } + return nil +} + +type Interface struct { + Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` + Value *Object `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Interface) Reset() { *m = Interface{} } +func (m *Interface) String() string { return proto.CompactTextString(m) } +func (*Interface) ProtoMessage() {} +func (*Interface) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{3} +} + +func (m *Interface) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Interface.Unmarshal(m, b) +} +func (m *Interface) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Interface.Marshal(b, m, deterministic) +} +func (m *Interface) XXX_Merge(src proto.Message) { + xxx_messageInfo_Interface.Merge(m, src) +} +func (m *Interface) XXX_Size() int { + return xxx_messageInfo_Interface.Size(m) +} +func (m *Interface) XXX_DiscardUnknown() { + xxx_messageInfo_Interface.DiscardUnknown(m) +} + +var xxx_messageInfo_Interface proto.InternalMessageInfo + +func (m *Interface) GetType() string { + if m != nil { + return m.Type + } + return "" +} + +func (m *Interface) GetValue() *Object { + if m != nil { + return m.Value + } + return nil +} + +type Struct struct { + Fields []*Field `protobuf:"bytes,1,rep,name=fields,proto3" json:"fields,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Struct) Reset() { *m = Struct{} } +func (m *Struct) String() string { return proto.CompactTextString(m) } +func (*Struct) ProtoMessage() {} +func (*Struct) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{4} +} + +func (m *Struct) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Struct.Unmarshal(m, b) +} +func (m *Struct) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Struct.Marshal(b, m, deterministic) +} +func (m *Struct) XXX_Merge(src proto.Message) { + xxx_messageInfo_Struct.Merge(m, src) +} +func (m *Struct) XXX_Size() int { + return xxx_messageInfo_Struct.Size(m) +} +func (m *Struct) XXX_DiscardUnknown() { + xxx_messageInfo_Struct.DiscardUnknown(m) +} + +var xxx_messageInfo_Struct proto.InternalMessageInfo + +func (m *Struct) GetFields() []*Field { + if m != nil { + return m.Fields + } + return nil +} + +type Field struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Value *Object `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Field) Reset() { *m = Field{} } +func (m *Field) String() string { return proto.CompactTextString(m) } +func (*Field) ProtoMessage() {} +func (*Field) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{5} +} + +func (m *Field) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Field.Unmarshal(m, b) +} +func (m *Field) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Field.Marshal(b, m, deterministic) +} +func (m *Field) XXX_Merge(src proto.Message) { + xxx_messageInfo_Field.Merge(m, src) +} +func (m *Field) XXX_Size() int { + return xxx_messageInfo_Field.Size(m) +} +func (m *Field) XXX_DiscardUnknown() { + xxx_messageInfo_Field.DiscardUnknown(m) +} + +var xxx_messageInfo_Field proto.InternalMessageInfo + +func (m *Field) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Field) GetValue() *Object { + if m != nil { + return m.Value + } + return nil +} + +type Uint16S struct { + Values []uint32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uint16S) Reset() { *m = Uint16S{} } +func (m *Uint16S) String() string { return proto.CompactTextString(m) } +func (*Uint16S) ProtoMessage() {} +func (*Uint16S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{6} +} + +func (m *Uint16S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uint16S.Unmarshal(m, b) +} +func (m *Uint16S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uint16S.Marshal(b, m, deterministic) +} +func (m *Uint16S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uint16S.Merge(m, src) +} +func (m *Uint16S) XXX_Size() int { + return xxx_messageInfo_Uint16S.Size(m) +} +func (m *Uint16S) XXX_DiscardUnknown() { + xxx_messageInfo_Uint16S.DiscardUnknown(m) +} + +var xxx_messageInfo_Uint16S proto.InternalMessageInfo + +func (m *Uint16S) GetValues() []uint32 { + if m != nil { + return m.Values + } + return nil +} + +type Uint32S struct { + Values []uint32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uint32S) Reset() { *m = Uint32S{} } +func (m *Uint32S) String() string { return proto.CompactTextString(m) } +func (*Uint32S) ProtoMessage() {} +func (*Uint32S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{7} +} + +func (m *Uint32S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uint32S.Unmarshal(m, b) +} +func (m *Uint32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uint32S.Marshal(b, m, deterministic) +} +func (m *Uint32S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uint32S.Merge(m, src) +} +func (m *Uint32S) XXX_Size() int { + return xxx_messageInfo_Uint32S.Size(m) +} +func (m *Uint32S) XXX_DiscardUnknown() { + xxx_messageInfo_Uint32S.DiscardUnknown(m) +} + +var xxx_messageInfo_Uint32S proto.InternalMessageInfo + +func (m *Uint32S) GetValues() []uint32 { + if m != nil { + return m.Values + } + return nil +} + +type Uint64S struct { + Values []uint64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uint64S) Reset() { *m = Uint64S{} } +func (m *Uint64S) String() string { return proto.CompactTextString(m) } +func (*Uint64S) ProtoMessage() {} +func (*Uint64S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{8} +} + +func (m *Uint64S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uint64S.Unmarshal(m, b) +} +func (m *Uint64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uint64S.Marshal(b, m, deterministic) +} +func (m *Uint64S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uint64S.Merge(m, src) +} +func (m *Uint64S) XXX_Size() int { + return xxx_messageInfo_Uint64S.Size(m) +} +func (m *Uint64S) XXX_DiscardUnknown() { + xxx_messageInfo_Uint64S.DiscardUnknown(m) +} + +var xxx_messageInfo_Uint64S proto.InternalMessageInfo + +func (m *Uint64S) GetValues() []uint64 { + if m != nil { + return m.Values + } + return nil +} + +type Uintptrs struct { + Values []uint64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Uintptrs) Reset() { *m = Uintptrs{} } +func (m *Uintptrs) String() string { return proto.CompactTextString(m) } +func (*Uintptrs) ProtoMessage() {} +func (*Uintptrs) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{9} +} + +func (m *Uintptrs) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Uintptrs.Unmarshal(m, b) +} +func (m *Uintptrs) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Uintptrs.Marshal(b, m, deterministic) +} +func (m *Uintptrs) XXX_Merge(src proto.Message) { + xxx_messageInfo_Uintptrs.Merge(m, src) +} +func (m *Uintptrs) XXX_Size() int { + return xxx_messageInfo_Uintptrs.Size(m) +} +func (m *Uintptrs) XXX_DiscardUnknown() { + xxx_messageInfo_Uintptrs.DiscardUnknown(m) +} + +var xxx_messageInfo_Uintptrs proto.InternalMessageInfo + +func (m *Uintptrs) GetValues() []uint64 { + if m != nil { + return m.Values + } + return nil +} + +type Int8S struct { + Values []byte `protobuf:"bytes,1,opt,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int8S) Reset() { *m = Int8S{} } +func (m *Int8S) String() string { return proto.CompactTextString(m) } +func (*Int8S) ProtoMessage() {} +func (*Int8S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{10} +} + +func (m *Int8S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int8S.Unmarshal(m, b) +} +func (m *Int8S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int8S.Marshal(b, m, deterministic) +} +func (m *Int8S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int8S.Merge(m, src) +} +func (m *Int8S) XXX_Size() int { + return xxx_messageInfo_Int8S.Size(m) +} +func (m *Int8S) XXX_DiscardUnknown() { + xxx_messageInfo_Int8S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int8S proto.InternalMessageInfo + +func (m *Int8S) GetValues() []byte { + if m != nil { + return m.Values + } + return nil +} + +type Int16S struct { + Values []int32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int16S) Reset() { *m = Int16S{} } +func (m *Int16S) String() string { return proto.CompactTextString(m) } +func (*Int16S) ProtoMessage() {} +func (*Int16S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{11} +} + +func (m *Int16S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int16S.Unmarshal(m, b) +} +func (m *Int16S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int16S.Marshal(b, m, deterministic) +} +func (m *Int16S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int16S.Merge(m, src) +} +func (m *Int16S) XXX_Size() int { + return xxx_messageInfo_Int16S.Size(m) +} +func (m *Int16S) XXX_DiscardUnknown() { + xxx_messageInfo_Int16S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int16S proto.InternalMessageInfo + +func (m *Int16S) GetValues() []int32 { + if m != nil { + return m.Values + } + return nil +} + +type Int32S struct { + Values []int32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int32S) Reset() { *m = Int32S{} } +func (m *Int32S) String() string { return proto.CompactTextString(m) } +func (*Int32S) ProtoMessage() {} +func (*Int32S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{12} +} + +func (m *Int32S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int32S.Unmarshal(m, b) +} +func (m *Int32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int32S.Marshal(b, m, deterministic) +} +func (m *Int32S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int32S.Merge(m, src) +} +func (m *Int32S) XXX_Size() int { + return xxx_messageInfo_Int32S.Size(m) +} +func (m *Int32S) XXX_DiscardUnknown() { + xxx_messageInfo_Int32S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int32S proto.InternalMessageInfo + +func (m *Int32S) GetValues() []int32 { + if m != nil { + return m.Values + } + return nil +} + +type Int64S struct { + Values []int64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Int64S) Reset() { *m = Int64S{} } +func (m *Int64S) String() string { return proto.CompactTextString(m) } +func (*Int64S) ProtoMessage() {} +func (*Int64S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{13} +} + +func (m *Int64S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Int64S.Unmarshal(m, b) +} +func (m *Int64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Int64S.Marshal(b, m, deterministic) +} +func (m *Int64S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Int64S.Merge(m, src) +} +func (m *Int64S) XXX_Size() int { + return xxx_messageInfo_Int64S.Size(m) +} +func (m *Int64S) XXX_DiscardUnknown() { + xxx_messageInfo_Int64S.DiscardUnknown(m) +} + +var xxx_messageInfo_Int64S proto.InternalMessageInfo + +func (m *Int64S) GetValues() []int64 { + if m != nil { + return m.Values + } + return nil +} + +type Bools struct { + Values []bool `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Bools) Reset() { *m = Bools{} } +func (m *Bools) String() string { return proto.CompactTextString(m) } +func (*Bools) ProtoMessage() {} +func (*Bools) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{14} +} + +func (m *Bools) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Bools.Unmarshal(m, b) +} +func (m *Bools) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Bools.Marshal(b, m, deterministic) +} +func (m *Bools) XXX_Merge(src proto.Message) { + xxx_messageInfo_Bools.Merge(m, src) +} +func (m *Bools) XXX_Size() int { + return xxx_messageInfo_Bools.Size(m) +} +func (m *Bools) XXX_DiscardUnknown() { + xxx_messageInfo_Bools.DiscardUnknown(m) +} + +var xxx_messageInfo_Bools proto.InternalMessageInfo + +func (m *Bools) GetValues() []bool { + if m != nil { + return m.Values + } + return nil +} + +type Float64S struct { + Values []float64 `protobuf:"fixed64,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Float64S) Reset() { *m = Float64S{} } +func (m *Float64S) String() string { return proto.CompactTextString(m) } +func (*Float64S) ProtoMessage() {} +func (*Float64S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{15} +} + +func (m *Float64S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Float64S.Unmarshal(m, b) +} +func (m *Float64S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Float64S.Marshal(b, m, deterministic) +} +func (m *Float64S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Float64S.Merge(m, src) +} +func (m *Float64S) XXX_Size() int { + return xxx_messageInfo_Float64S.Size(m) +} +func (m *Float64S) XXX_DiscardUnknown() { + xxx_messageInfo_Float64S.DiscardUnknown(m) +} + +var xxx_messageInfo_Float64S proto.InternalMessageInfo + +func (m *Float64S) GetValues() []float64 { + if m != nil { + return m.Values + } + return nil +} + +type Float32S struct { + Values []float32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Float32S) Reset() { *m = Float32S{} } +func (m *Float32S) String() string { return proto.CompactTextString(m) } +func (*Float32S) ProtoMessage() {} +func (*Float32S) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{16} +} + +func (m *Float32S) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Float32S.Unmarshal(m, b) +} +func (m *Float32S) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Float32S.Marshal(b, m, deterministic) +} +func (m *Float32S) XXX_Merge(src proto.Message) { + xxx_messageInfo_Float32S.Merge(m, src) +} +func (m *Float32S) XXX_Size() int { + return xxx_messageInfo_Float32S.Size(m) +} +func (m *Float32S) XXX_DiscardUnknown() { + xxx_messageInfo_Float32S.DiscardUnknown(m) +} + +var xxx_messageInfo_Float32S proto.InternalMessageInfo + +func (m *Float32S) GetValues() []float32 { + if m != nil { + return m.Values + } + return nil +} + +type Object struct { + // Types that are valid to be assigned to Value: + // *Object_BoolValue + // *Object_StringValue + // *Object_Int64Value + // *Object_Uint64Value + // *Object_DoubleValue + // *Object_RefValue + // *Object_SliceValue + // *Object_ArrayValue + // *Object_InterfaceValue + // *Object_StructValue + // *Object_MapValue + // *Object_ByteArrayValue + // *Object_Uint16ArrayValue + // *Object_Uint32ArrayValue + // *Object_Uint64ArrayValue + // *Object_UintptrArrayValue + // *Object_Int8ArrayValue + // *Object_Int16ArrayValue + // *Object_Int32ArrayValue + // *Object_Int64ArrayValue + // *Object_BoolArrayValue + // *Object_Float64ArrayValue + // *Object_Float32ArrayValue + Value isObject_Value `protobuf_oneof:"value"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Object) Reset() { *m = Object{} } +func (m *Object) String() string { return proto.CompactTextString(m) } +func (*Object) ProtoMessage() {} +func (*Object) Descriptor() ([]byte, []int) { + return fileDescriptor_3dee2c1912d4d62d, []int{17} +} + +func (m *Object) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Object.Unmarshal(m, b) +} +func (m *Object) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Object.Marshal(b, m, deterministic) +} +func (m *Object) XXX_Merge(src proto.Message) { + xxx_messageInfo_Object.Merge(m, src) +} +func (m *Object) XXX_Size() int { + return xxx_messageInfo_Object.Size(m) +} +func (m *Object) XXX_DiscardUnknown() { + xxx_messageInfo_Object.DiscardUnknown(m) +} + +var xxx_messageInfo_Object proto.InternalMessageInfo + +type isObject_Value interface { + isObject_Value() +} + +type Object_BoolValue struct { + BoolValue bool `protobuf:"varint,1,opt,name=bool_value,json=boolValue,proto3,oneof"` +} + +type Object_StringValue struct { + StringValue []byte `protobuf:"bytes,2,opt,name=string_value,json=stringValue,proto3,oneof"` +} + +type Object_Int64Value struct { + Int64Value int64 `protobuf:"varint,3,opt,name=int64_value,json=int64Value,proto3,oneof"` +} + +type Object_Uint64Value struct { + Uint64Value uint64 `protobuf:"varint,4,opt,name=uint64_value,json=uint64Value,proto3,oneof"` +} + +type Object_DoubleValue struct { + DoubleValue float64 `protobuf:"fixed64,5,opt,name=double_value,json=doubleValue,proto3,oneof"` +} + +type Object_RefValue struct { + RefValue uint64 `protobuf:"varint,6,opt,name=ref_value,json=refValue,proto3,oneof"` +} + +type Object_SliceValue struct { + SliceValue *Slice `protobuf:"bytes,7,opt,name=slice_value,json=sliceValue,proto3,oneof"` +} + +type Object_ArrayValue struct { + ArrayValue *Array `protobuf:"bytes,8,opt,name=array_value,json=arrayValue,proto3,oneof"` +} + +type Object_InterfaceValue struct { + InterfaceValue *Interface `protobuf:"bytes,9,opt,name=interface_value,json=interfaceValue,proto3,oneof"` +} + +type Object_StructValue struct { + StructValue *Struct `protobuf:"bytes,10,opt,name=struct_value,json=structValue,proto3,oneof"` +} + +type Object_MapValue struct { + MapValue *Map `protobuf:"bytes,11,opt,name=map_value,json=mapValue,proto3,oneof"` +} + +type Object_ByteArrayValue struct { + ByteArrayValue []byte `protobuf:"bytes,12,opt,name=byte_array_value,json=byteArrayValue,proto3,oneof"` +} + +type Object_Uint16ArrayValue struct { + Uint16ArrayValue *Uint16S `protobuf:"bytes,13,opt,name=uint16_array_value,json=uint16ArrayValue,proto3,oneof"` +} + +type Object_Uint32ArrayValue struct { + Uint32ArrayValue *Uint32S `protobuf:"bytes,14,opt,name=uint32_array_value,json=uint32ArrayValue,proto3,oneof"` +} + +type Object_Uint64ArrayValue struct { + Uint64ArrayValue *Uint64S `protobuf:"bytes,15,opt,name=uint64_array_value,json=uint64ArrayValue,proto3,oneof"` +} + +type Object_UintptrArrayValue struct { + UintptrArrayValue *Uintptrs `protobuf:"bytes,16,opt,name=uintptr_array_value,json=uintptrArrayValue,proto3,oneof"` +} + +type Object_Int8ArrayValue struct { + Int8ArrayValue *Int8S `protobuf:"bytes,17,opt,name=int8_array_value,json=int8ArrayValue,proto3,oneof"` +} + +type Object_Int16ArrayValue struct { + Int16ArrayValue *Int16S `protobuf:"bytes,18,opt,name=int16_array_value,json=int16ArrayValue,proto3,oneof"` +} + +type Object_Int32ArrayValue struct { + Int32ArrayValue *Int32S `protobuf:"bytes,19,opt,name=int32_array_value,json=int32ArrayValue,proto3,oneof"` +} + +type Object_Int64ArrayValue struct { + Int64ArrayValue *Int64S `protobuf:"bytes,20,opt,name=int64_array_value,json=int64ArrayValue,proto3,oneof"` +} + +type Object_BoolArrayValue struct { + BoolArrayValue *Bools `protobuf:"bytes,21,opt,name=bool_array_value,json=boolArrayValue,proto3,oneof"` +} + +type Object_Float64ArrayValue struct { + Float64ArrayValue *Float64S `protobuf:"bytes,22,opt,name=float64_array_value,json=float64ArrayValue,proto3,oneof"` +} + +type Object_Float32ArrayValue struct { + Float32ArrayValue *Float32S `protobuf:"bytes,23,opt,name=float32_array_value,json=float32ArrayValue,proto3,oneof"` +} + +func (*Object_BoolValue) isObject_Value() {} + +func (*Object_StringValue) isObject_Value() {} + +func (*Object_Int64Value) isObject_Value() {} + +func (*Object_Uint64Value) isObject_Value() {} + +func (*Object_DoubleValue) isObject_Value() {} + +func (*Object_RefValue) isObject_Value() {} + +func (*Object_SliceValue) isObject_Value() {} + +func (*Object_ArrayValue) isObject_Value() {} + +func (*Object_InterfaceValue) isObject_Value() {} + +func (*Object_StructValue) isObject_Value() {} + +func (*Object_MapValue) isObject_Value() {} + +func (*Object_ByteArrayValue) isObject_Value() {} + +func (*Object_Uint16ArrayValue) isObject_Value() {} + +func (*Object_Uint32ArrayValue) isObject_Value() {} + +func (*Object_Uint64ArrayValue) isObject_Value() {} + +func (*Object_UintptrArrayValue) isObject_Value() {} + +func (*Object_Int8ArrayValue) isObject_Value() {} + +func (*Object_Int16ArrayValue) isObject_Value() {} + +func (*Object_Int32ArrayValue) isObject_Value() {} + +func (*Object_Int64ArrayValue) isObject_Value() {} + +func (*Object_BoolArrayValue) isObject_Value() {} + +func (*Object_Float64ArrayValue) isObject_Value() {} + +func (*Object_Float32ArrayValue) isObject_Value() {} + +func (m *Object) GetValue() isObject_Value { + if m != nil { + return m.Value + } + return nil +} + +func (m *Object) GetBoolValue() bool { + if x, ok := m.GetValue().(*Object_BoolValue); ok { + return x.BoolValue + } + return false +} + +func (m *Object) GetStringValue() []byte { + if x, ok := m.GetValue().(*Object_StringValue); ok { + return x.StringValue + } + return nil +} + +func (m *Object) GetInt64Value() int64 { + if x, ok := m.GetValue().(*Object_Int64Value); ok { + return x.Int64Value + } + return 0 +} + +func (m *Object) GetUint64Value() uint64 { + if x, ok := m.GetValue().(*Object_Uint64Value); ok { + return x.Uint64Value + } + return 0 +} + +func (m *Object) GetDoubleValue() float64 { + if x, ok := m.GetValue().(*Object_DoubleValue); ok { + return x.DoubleValue + } + return 0 +} + +func (m *Object) GetRefValue() uint64 { + if x, ok := m.GetValue().(*Object_RefValue); ok { + return x.RefValue + } + return 0 +} + +func (m *Object) GetSliceValue() *Slice { + if x, ok := m.GetValue().(*Object_SliceValue); ok { + return x.SliceValue + } + return nil +} + +func (m *Object) GetArrayValue() *Array { + if x, ok := m.GetValue().(*Object_ArrayValue); ok { + return x.ArrayValue + } + return nil +} + +func (m *Object) GetInterfaceValue() *Interface { + if x, ok := m.GetValue().(*Object_InterfaceValue); ok { + return x.InterfaceValue + } + return nil +} + +func (m *Object) GetStructValue() *Struct { + if x, ok := m.GetValue().(*Object_StructValue); ok { + return x.StructValue + } + return nil +} + +func (m *Object) GetMapValue() *Map { + if x, ok := m.GetValue().(*Object_MapValue); ok { + return x.MapValue + } + return nil +} + +func (m *Object) GetByteArrayValue() []byte { + if x, ok := m.GetValue().(*Object_ByteArrayValue); ok { + return x.ByteArrayValue + } + return nil +} + +func (m *Object) GetUint16ArrayValue() *Uint16S { + if x, ok := m.GetValue().(*Object_Uint16ArrayValue); ok { + return x.Uint16ArrayValue + } + return nil +} + +func (m *Object) GetUint32ArrayValue() *Uint32S { + if x, ok := m.GetValue().(*Object_Uint32ArrayValue); ok { + return x.Uint32ArrayValue + } + return nil +} + +func (m *Object) GetUint64ArrayValue() *Uint64S { + if x, ok := m.GetValue().(*Object_Uint64ArrayValue); ok { + return x.Uint64ArrayValue + } + return nil +} + +func (m *Object) GetUintptrArrayValue() *Uintptrs { + if x, ok := m.GetValue().(*Object_UintptrArrayValue); ok { + return x.UintptrArrayValue + } + return nil +} + +func (m *Object) GetInt8ArrayValue() *Int8S { + if x, ok := m.GetValue().(*Object_Int8ArrayValue); ok { + return x.Int8ArrayValue + } + return nil +} + +func (m *Object) GetInt16ArrayValue() *Int16S { + if x, ok := m.GetValue().(*Object_Int16ArrayValue); ok { + return x.Int16ArrayValue + } + return nil +} + +func (m *Object) GetInt32ArrayValue() *Int32S { + if x, ok := m.GetValue().(*Object_Int32ArrayValue); ok { + return x.Int32ArrayValue + } + return nil +} + +func (m *Object) GetInt64ArrayValue() *Int64S { + if x, ok := m.GetValue().(*Object_Int64ArrayValue); ok { + return x.Int64ArrayValue + } + return nil +} + +func (m *Object) GetBoolArrayValue() *Bools { + if x, ok := m.GetValue().(*Object_BoolArrayValue); ok { + return x.BoolArrayValue + } + return nil +} + +func (m *Object) GetFloat64ArrayValue() *Float64S { + if x, ok := m.GetValue().(*Object_Float64ArrayValue); ok { + return x.Float64ArrayValue + } + return nil +} + +func (m *Object) GetFloat32ArrayValue() *Float32S { + if x, ok := m.GetValue().(*Object_Float32ArrayValue); ok { + return x.Float32ArrayValue + } + return nil +} + +// XXX_OneofWrappers is for the internal use of the proto package. +func (*Object) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*Object_BoolValue)(nil), + (*Object_StringValue)(nil), + (*Object_Int64Value)(nil), + (*Object_Uint64Value)(nil), + (*Object_DoubleValue)(nil), + (*Object_RefValue)(nil), + (*Object_SliceValue)(nil), + (*Object_ArrayValue)(nil), + (*Object_InterfaceValue)(nil), + (*Object_StructValue)(nil), + (*Object_MapValue)(nil), + (*Object_ByteArrayValue)(nil), + (*Object_Uint16ArrayValue)(nil), + (*Object_Uint32ArrayValue)(nil), + (*Object_Uint64ArrayValue)(nil), + (*Object_UintptrArrayValue)(nil), + (*Object_Int8ArrayValue)(nil), + (*Object_Int16ArrayValue)(nil), + (*Object_Int32ArrayValue)(nil), + (*Object_Int64ArrayValue)(nil), + (*Object_BoolArrayValue)(nil), + (*Object_Float64ArrayValue)(nil), + (*Object_Float32ArrayValue)(nil), + } +} + +func init() { + proto.RegisterType((*Slice)(nil), "gvisor.state.statefile.Slice") + proto.RegisterType((*Array)(nil), "gvisor.state.statefile.Array") + proto.RegisterType((*Map)(nil), "gvisor.state.statefile.Map") + proto.RegisterType((*Interface)(nil), "gvisor.state.statefile.Interface") + proto.RegisterType((*Struct)(nil), "gvisor.state.statefile.Struct") + proto.RegisterType((*Field)(nil), "gvisor.state.statefile.Field") + proto.RegisterType((*Uint16S)(nil), "gvisor.state.statefile.Uint16s") + proto.RegisterType((*Uint32S)(nil), "gvisor.state.statefile.Uint32s") + proto.RegisterType((*Uint64S)(nil), "gvisor.state.statefile.Uint64s") + proto.RegisterType((*Uintptrs)(nil), "gvisor.state.statefile.Uintptrs") + proto.RegisterType((*Int8S)(nil), "gvisor.state.statefile.Int8s") + proto.RegisterType((*Int16S)(nil), "gvisor.state.statefile.Int16s") + proto.RegisterType((*Int32S)(nil), "gvisor.state.statefile.Int32s") + proto.RegisterType((*Int64S)(nil), "gvisor.state.statefile.Int64s") + proto.RegisterType((*Bools)(nil), "gvisor.state.statefile.Bools") + proto.RegisterType((*Float64S)(nil), "gvisor.state.statefile.Float64s") + proto.RegisterType((*Float32S)(nil), "gvisor.state.statefile.Float32s") + proto.RegisterType((*Object)(nil), "gvisor.state.statefile.Object") +} + +func init() { proto.RegisterFile("pkg/state/object.proto", fileDescriptor_3dee2c1912d4d62d) } + +var fileDescriptor_3dee2c1912d4d62d = []byte{ + // 781 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x96, 0x6f, 0x4f, 0xda, 0x5e, + 0x14, 0xc7, 0xa9, 0x40, 0x29, 0x07, 0x14, 0xb8, 0xfe, 0x7e, 0x8c, 0xcc, 0x38, 0xb1, 0x7b, 0x42, + 0xf6, 0x00, 0x33, 0x60, 0xc4, 0xf8, 0x64, 0x53, 0x13, 0x03, 0xc9, 0x8c, 0x59, 0x8d, 0xcb, 0x9e, + 0x99, 0x52, 0x2f, 0xac, 0xb3, 0xb6, 0x5d, 0x7b, 0x6b, 0xc2, 0xcb, 0xdc, 0x3b, 0x5a, 0xee, 0x1f, + 0xae, 0xfd, 0x03, 0xc5, 0xec, 0x89, 0xa1, 0xb7, 0xdf, 0xf3, 0xe1, 0xdc, 0xf3, 0x3d, 0xe7, 0x08, + 0xb4, 0xfd, 0xc7, 0xc5, 0x49, 0x48, 0x4c, 0x82, 0x4f, 0xbc, 0xd9, 0x2f, 0x6c, 0x91, 0xbe, 0x1f, + 0x78, 0xc4, 0x43, 0xed, 0xc5, 0xb3, 0x1d, 0x7a, 0x41, 0x9f, 0xbd, 0xe2, 0x7f, 0xe7, 0xb6, 0x83, + 0xf5, 0x1f, 0x50, 0xbe, 0x75, 0x6c, 0x0b, 0xa3, 0x36, 0xa8, 0x0e, 0x76, 0x17, 0xe4, 0x67, 0x47, + 0xe9, 0x2a, 0xbd, 0x5d, 0x43, 0x3c, 0xa1, 0xb7, 0xa0, 0x59, 0xa6, 0x6f, 0x5a, 0x36, 0x59, 0x76, + 0x76, 0xd8, 0x1b, 0xf9, 0x8c, 0x0e, 0xa0, 0x1a, 0xe0, 0xf9, 0xfd, 0xb3, 0xe9, 0x44, 0xb8, 0x53, + 0xec, 0x2a, 0xbd, 0x92, 0xa1, 0x05, 0x78, 0xfe, 0x9d, 0x3e, 0xeb, 0x97, 0x50, 0x3e, 0x0f, 0x02, + 0x73, 0x89, 0xce, 0x40, 0xb3, 0x3c, 0x97, 0x60, 0x97, 0x84, 0x1d, 0xa5, 0x5b, 0xec, 0xd5, 0x06, + 0xef, 0xfa, 0xeb, 0xb3, 0xe9, 0xdf, 0xb0, 0x94, 0x0d, 0xa9, 0xd7, 0x7f, 0x43, 0xf1, 0xda, 0xf4, + 0xd1, 0x00, 0x4a, 0x8f, 0x78, 0xf9, 0xda, 0x70, 0xa6, 0x45, 0x63, 0x50, 0x59, 0x62, 0x61, 0x67, + 0xe7, 0x55, 0x51, 0x42, 0xad, 0xdf, 0x41, 0x75, 0xea, 0x12, 0x1c, 0xcc, 0x4d, 0x0b, 0x23, 0x04, + 0x25, 0xb2, 0xf4, 0x31, 0xab, 0x49, 0xd5, 0x60, 0x9f, 0xd1, 0x08, 0xca, 0xfc, 0xc6, 0xb4, 0x1c, + 0xdb, 0xb9, 0x5c, 0xac, 0x7f, 0x06, 0xf5, 0x96, 0x04, 0x91, 0x45, 0xd0, 0x27, 0x50, 0xe7, 0x36, + 0x76, 0x1e, 0x56, 0xd7, 0x39, 0xdc, 0x04, 0xb8, 0xa2, 0x2a, 0x43, 0x88, 0xf5, 0x6f, 0x50, 0x66, + 0x07, 0x34, 0x27, 0xd7, 0x7c, 0x92, 0x39, 0xd1, 0xcf, 0xff, 0x98, 0xd3, 0x31, 0x54, 0xee, 0x6c, + 0x97, 0x7c, 0x1c, 0x87, 0xd4, 0x7e, 0x51, 0x2d, 0x9a, 0xd4, 0xae, 0xac, 0x86, 0x90, 0x0c, 0x07, + 0x69, 0x49, 0x25, 0x2d, 0x19, 0x8f, 0xd2, 0x12, 0x55, 0x4a, 0x74, 0xd0, 0xa8, 0xc4, 0x27, 0xc1, + 0x66, 0xcd, 0x11, 0x94, 0xa7, 0x2e, 0x39, 0x4d, 0x0a, 0x94, 0x5e, 0x5d, 0x0a, 0xba, 0xa0, 0x4e, + 0xd7, 0x25, 0x5b, 0x4e, 0x29, 0xb2, 0xb9, 0x36, 0x52, 0x8a, 0x6c, 0xaa, 0xcd, 0x78, 0x1a, 0x17, + 0x9e, 0xe7, 0xa4, 0x05, 0x5a, 0xfc, 0x2e, 0x57, 0x8e, 0x67, 0xae, 0x81, 0x28, 0x19, 0x4d, 0x36, + 0x95, 0x1d, 0xa9, 0xf9, 0x53, 0x03, 0x95, 0xdb, 0x81, 0x8e, 0x00, 0x66, 0x9e, 0xe7, 0x88, 0x41, + 0xa2, 0xb7, 0xd6, 0x26, 0x05, 0xa3, 0x4a, 0xcf, 0xd8, 0x2c, 0xa1, 0xf7, 0x50, 0x0f, 0x49, 0x60, + 0xbb, 0x8b, 0xfb, 0x17, 0x97, 0xeb, 0x93, 0x82, 0x51, 0xe3, 0xa7, 0x5c, 0x74, 0x0c, 0x35, 0x66, + 0x43, 0x6c, 0x1e, 0x8b, 0x93, 0x82, 0x01, 0xec, 0x50, 0x72, 0xa2, 0xb8, 0xa6, 0x44, 0x67, 0x96, + 0x72, 0xa2, 0xa4, 0xe8, 0xc1, 0x8b, 0x66, 0x0e, 0x16, 0xa2, 0x72, 0x57, 0xe9, 0x29, 0x54, 0xc4, + 0x4f, 0xb9, 0xe8, 0x30, 0x3e, 0xfa, 0xaa, 0xc0, 0xc8, 0xe1, 0x47, 0x5f, 0xa0, 0x16, 0xd2, 0xb5, + 0x22, 0x04, 0x15, 0xd6, 0x95, 0x1b, 0x1b, 0x9d, 0x6d, 0x20, 0x9a, 0x2a, 0x8b, 0x91, 0x04, 0x93, + 0xae, 0x0f, 0x41, 0xd0, 0xf2, 0x09, 0x6c, 0xd3, 0x50, 0x02, 0x8b, 0xe1, 0x84, 0xaf, 0xd0, 0xb0, + 0x57, 0x83, 0x2c, 0x28, 0x55, 0x46, 0x39, 0xde, 0x44, 0x91, 0x73, 0x3f, 0x29, 0x18, 0x7b, 0x32, + 0x96, 0xd3, 0x2e, 0x99, 0x05, 0x91, 0x45, 0x04, 0x0a, 0xf2, 0x07, 0x8d, 0xcf, 0xba, 0xb0, 0x28, + 0xb2, 0x08, 0x87, 0x9c, 0x41, 0xf5, 0xc9, 0xf4, 0x05, 0xa1, 0xc6, 0x08, 0x07, 0x9b, 0x08, 0xd7, + 0xa6, 0x4f, 0x4b, 0xfa, 0x64, 0xfa, 0x3c, 0xf6, 0x03, 0x34, 0x67, 0x4b, 0x82, 0xef, 0xe3, 0x55, + 0xa9, 0x8b, 0x3e, 0xd8, 0xa3, 0x6f, 0xce, 0x5f, 0xae, 0x7e, 0x03, 0x28, 0x62, 0x83, 0x9d, 0x50, + 0xef, 0xb2, 0x2f, 0x3c, 0xda, 0xf4, 0x85, 0x62, 0x15, 0x4c, 0x0a, 0x46, 0x93, 0x07, 0x67, 0x81, + 0xc3, 0x41, 0x02, 0xb8, 0xb7, 0x1d, 0x38, 0x1c, 0x48, 0xe0, 0x70, 0x90, 0x05, 0x8e, 0x47, 0x09, + 0x60, 0x63, 0x3b, 0x70, 0x3c, 0x92, 0xc0, 0xf1, 0x28, 0x06, 0x34, 0x60, 0x3f, 0xe2, 0x2b, 0x26, + 0x41, 0x6c, 0x32, 0x62, 0x37, 0x8f, 0x48, 0xb7, 0xd2, 0xa4, 0x60, 0xb4, 0x44, 0x78, 0x8c, 0x39, + 0x85, 0xa6, 0xed, 0x92, 0xd3, 0x04, 0xb0, 0x95, 0xdf, 0x88, 0x6c, 0x85, 0x89, 0xf6, 0x39, 0x3d, + 0x8f, 0x37, 0x63, 0x2b, 0x6b, 0x08, 0xca, 0xef, 0xa1, 0xe9, 0xca, 0x8f, 0x46, 0xda, 0x0e, 0x4e, + 0x4b, 0xb9, 0xb1, 0xbf, 0x95, 0xc6, 0xcd, 0x68, 0xa4, 0xbd, 0xe0, 0xb4, 0x94, 0x15, 0xff, 0x6d, + 0xa5, 0x71, 0x27, 0x1a, 0x69, 0x23, 0xa6, 0xd0, 0x64, 0xcb, 0x2c, 0x0e, 0xfb, 0x3f, 0xbf, 0x68, + 0x6c, 0xe1, 0xb2, 0x36, 0xf6, 0x3c, 0x27, 0xe9, 0xe9, 0x9c, 0xaf, 0xda, 0x04, 0xad, 0x9d, 0xef, + 0xe9, 0x6a, 0x3b, 0x53, 0x4f, 0x45, 0xf8, 0x1a, 0x66, 0xaa, 0x78, 0x6f, 0x5e, 0xc1, 0xe4, 0xe5, + 0x6b, 0x89, 0xf0, 0x17, 0xe6, 0x45, 0x45, 0xfc, 0xf7, 0x9d, 0xa9, 0xec, 0xc7, 0xd6, 0xf0, 0x6f, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x84, 0x69, 0xc9, 0x45, 0x86, 0x09, 0x00, 0x00, +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go deleted file mode 100644 index 7c24bbcda..000000000 --- a/pkg/state/state_test.go +++ /dev/null @@ -1,720 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "bytes" - "io/ioutil" - "math" - "reflect" - "testing" -) - -// TestCase is used to define a single success/failure testcase of -// serialization of a set of objects. -type TestCase struct { - // Name is the name of the test case. - Name string - - // Objects is the list of values to serialize. - Objects []interface{} - - // Fail is whether the test case is supposed to fail or not. - Fail bool -} - -// runTest runs all testcases. -func runTest(t *testing.T, tests []TestCase) { - for _, test := range tests { - t.Logf("TEST %s:", test.Name) - for i, root := range test.Objects { - t.Logf(" case#%d: %#v", i, root) - - // Save the passed object. - saveBuffer := &bytes.Buffer{} - saveObjectPtr := reflect.New(reflect.TypeOf(root)) - saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Save failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Save failed as expected: %v", err) - continue - } - - // Load a new copy of the object. - loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Load failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Load failed as expected: %v", err) - continue - } - - // Compare the values. - loadedValue := loadObjectPtr.Elem().Interface() - if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail { - t.Errorf(" FAIL: Objects differs; got %#v", loadedValue) - continue - } else if !eq { - t.Logf(" PASS: Object different as expected.") - continue - } - - // Everything went okay. Is that good? - if test.Fail { - t.Errorf(" FAIL: Unexpected success.") - } else { - t.Logf(" PASS: Success.") - } - } - } -} - -// dumbStruct is a struct which does not implement the loader/saver interface. -// We expect that serialization of this struct will fail. -type dumbStruct struct { - A int - B int -} - -// smartStruct is a struct which does implement the loader/saver interface. -// We expect that serialization of this struct will succeed. -type smartStruct struct { - A int - B int -} - -func (s *smartStruct) save(m Map) { - m.Save("A", &s.A) - m.Save("B", &s.B) -} - -func (s *smartStruct) load(m Map) { - m.Load("A", &s.A) - m.Load("B", &s.B) -} - -// valueLoadStruct uses a value load. -type valueLoadStruct struct { - v int -} - -func (v *valueLoadStruct) save(m Map) { - m.SaveValue("v", v.v) -} - -func (v *valueLoadStruct) load(m Map) { - m.LoadValue("v", new(int), func(value interface{}) { - v.v = value.(int) - }) -} - -// afterLoadStruct has an AfterLoad function. -type afterLoadStruct struct { - v int -} - -func (a *afterLoadStruct) save(m Map) { -} - -func (a *afterLoadStruct) load(m Map) { - m.AfterLoad(func() { - a.v++ - }) -} - -// genericContainer is a generic dispatcher. -type genericContainer struct { - v interface{} -} - -func (g *genericContainer) save(m Map) { - m.Save("v", &g.v) -} - -func (g *genericContainer) load(m Map) { - m.Load("v", &g.v) -} - -// sliceContainer is a generic slice. -type sliceContainer struct { - v []interface{} -} - -func (s *sliceContainer) save(m Map) { - m.Save("v", &s.v) -} - -func (s *sliceContainer) load(m Map) { - m.Load("v", &s.v) -} - -// mapContainer is a generic map. -type mapContainer struct { - v map[int]interface{} -} - -func (mc *mapContainer) save(m Map) { - m.Save("v", &mc.v) -} - -func (mc *mapContainer) load(m Map) { - // Some of the test cases below assume legacy behavior wherein maps - // will automatically inherit dependencies. - m.LoadWait("v", &mc.v) -} - -// dumbMap is a map which does not implement the loader/saver interface. -// Serialization of this map will default to the standard encode/decode logic. -type dumbMap map[string]int - -// pointerStruct contains various pointers, shared and non-shared, and pointers -// to pointers. We expect that serialization will respect the structure. -type pointerStruct struct { - A *int - B *int - C *int - D *int - - AA **int - BB **int -} - -func (p *pointerStruct) save(m Map) { - m.Save("A", &p.A) - m.Save("B", &p.B) - m.Save("C", &p.C) - m.Save("D", &p.D) - m.Save("AA", &p.AA) - m.Save("BB", &p.BB) -} - -func (p *pointerStruct) load(m Map) { - m.Load("A", &p.A) - m.Load("B", &p.B) - m.Load("C", &p.C) - m.Load("D", &p.D) - m.Load("AA", &p.AA) - m.Load("BB", &p.BB) -} - -// testInterface is a trivial interface example. -type testInterface interface { - Foo() -} - -// testImpl is a trivial implementation of testInterface. -type testImpl struct { -} - -// Foo satisfies testInterface. -func (t *testImpl) Foo() { -} - -// testImpl is trivially serializable. -func (t *testImpl) save(m Map) { -} - -// testImpl is trivially serializable. -func (t *testImpl) load(m Map) { -} - -// testI demonstrates interface dispatching. -type testI struct { - I testInterface -} - -func (t *testI) save(m Map) { - m.Save("I", &t.I) -} - -func (t *testI) load(m Map) { - m.Load("I", &t.I) -} - -// cycleStruct is used to implement basic cycles. -type cycleStruct struct { - c *cycleStruct -} - -func (c *cycleStruct) save(m Map) { - m.Save("c", &c.c) -} - -func (c *cycleStruct) load(m Map) { - m.Load("c", &c.c) -} - -// badCycleStruct actually has deadlocking dependencies. -// -// This should pass if b.b = {nil|b} and fail otherwise. -type badCycleStruct struct { - b *badCycleStruct -} - -func (b *badCycleStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *badCycleStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(func() { - // This is not executable, since AfterLoad requires that the - // object and all dependencies are complete. This should cause - // a deadlock error during load. - }) -} - -// emptyStructPointer points to an empty struct. -type emptyStructPointer struct { - nothing *struct{} -} - -func (e *emptyStructPointer) save(m Map) { - m.Save("nothing", &e.nothing) -} - -func (e *emptyStructPointer) load(m Map) { - m.Load("nothing", &e.nothing) -} - -// truncateInteger truncates an integer. -type truncateInteger struct { - v int64 - v2 int32 -} - -func (t *truncateInteger) save(m Map) { - t.v2 = int32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = int64(t.v2) -} - -// truncateUnsignedInteger truncates an unsigned integer. -type truncateUnsignedInteger struct { - v uint64 - v2 uint32 -} - -func (t *truncateUnsignedInteger) save(m Map) { - t.v2 = uint32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateUnsignedInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = uint64(t.v2) -} - -// truncateFloat truncates a floating point number. -type truncateFloat struct { - v float64 - v2 float32 -} - -func (t *truncateFloat) save(m Map) { - t.v2 = float32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateFloat) load(m Map) { - m.Load("v", &t.v2) - t.v = float64(t.v2) -} - -func TestTypes(t *testing.T) { - // x and y are basic integers, while xp points to x. - x := 1 - y := 2 - xp := &x - - // cs is a single object cycle. - cs := cycleStruct{nil} - cs.c = &cs - - // cs1 and cs2 are in a two object cycle. - cs1 := cycleStruct{nil} - cs2 := cycleStruct{nil} - cs1.c = &cs2 - cs2.c = &cs1 - - // bs is a single object cycle. - bs := badCycleStruct{nil} - bs.b = &bs - - // bs2 and bs2 are in a deadlocking cycle. - bs1 := badCycleStruct{nil} - bs2 := badCycleStruct{nil} - bs1.b = &bs2 - bs2.b = &bs1 - - // regular nils. - var ( - nilmap dumbMap - nilslice []byte - ) - - // embed points to embedded fields. - embed1 := pointerStruct{} - embed1.AA = &embed1.A - embed2 := pointerStruct{} - embed2.BB = &embed2.B - - // es1 contains two structs pointing to the same empty struct. - es := emptyStructPointer{new(struct{})} - es1 := []emptyStructPointer{es, es} - - tests := []TestCase{ - { - Name: "bool", - Objects: []interface{}{ - true, - false, - }, - }, - { - Name: "integers", - Objects: []interface{}{ - int(0), - int(1), - int(-1), - int8(0), - int8(1), - int8(-1), - int16(0), - int16(1), - int16(-1), - int32(0), - int32(1), - int32(-1), - int64(0), - int64(1), - int64(-1), - }, - }, - { - Name: "unsigned integers", - Objects: []interface{}{ - uint(0), - uint(1), - uint8(0), - uint8(1), - uint16(0), - uint16(1), - uint32(1), - uint64(0), - uint64(1), - }, - }, - { - Name: "strings", - Objects: []interface{}{ - "", - "foo", - "bar", - "\xa0", - }, - }, - { - Name: "slices", - Objects: []interface{}{ - []int{-1, 0, 1}, - []*int{&x, &x, &x}, - []int{1, 2, 3}[0:1], - []int{1, 2, 3}[1:2], - make([]byte, 32), - make([]byte, 32)[:16], - make([]byte, 32)[:16:20], - nilslice, - }, - }, - { - Name: "arrays", - Objects: []interface{}{ - &[1048576]bool{false, true, false, true}, - &[1048576]uint8{0, 1, 2, 3}, - &[1048576]byte{0, 1, 2, 3}, - &[1048576]uint16{0, 1, 2, 3}, - &[1048576]uint{0, 1, 2, 3}, - &[1048576]uint32{0, 1, 2, 3}, - &[1048576]uint64{0, 1, 2, 3}, - &[1048576]uintptr{0, 1, 2, 3}, - &[1048576]int8{0, -1, -2, -3}, - &[1048576]int16{0, -1, -2, -3}, - &[1048576]int32{0, -1, -2, -3}, - &[1048576]int64{0, -1, -2, -3}, - &[1048576]float32{0, 1.1, 2.2, 3.3}, - &[1048576]float64{0, 1.1, 2.2, 3.3}, - }, - }, - { - Name: "pointers", - Objects: []interface{}{ - &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp}, - &pointerStruct{}, - }, - }, - { - Name: "empty struct", - Objects: []interface{}{ - struct{}{}, - }, - }, - { - Name: "unenlightened structs", - Objects: []interface{}{ - &dumbStruct{A: 1, B: 2}, - }, - Fail: true, - }, - { - Name: "enlightened structs", - Objects: []interface{}{ - &smartStruct{A: 1, B: 2}, - }, - }, - { - Name: "load-hooks", - Objects: []interface{}{ - &afterLoadStruct{v: 1}, - &valueLoadStruct{v: 1}, - &genericContainer{v: &afterLoadStruct{v: 1}}, - &genericContainer{v: &valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, - }, - }, - { - Name: "maps", - Objects: []interface{}{ - dumbMap{"a": -1, "b": 0, "c": 1}, - map[smartStruct]int{{}: 0, {A: 1}: 1}, - nilmap, - &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}}, - }, - }, - { - Name: "interfaces", - Objects: []interface{}{ - &testI{&testImpl{}}, - &testI{nil}, - &testI{(*testImpl)(nil)}, - }, - }, - { - Name: "unregistered-interfaces", - Objects: []interface{}{ - &genericContainer{v: afterLoadStruct{v: 1}}, - &genericContainer{v: valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}}, - }, - Fail: true, - }, - { - Name: "cycles", - Objects: []interface{}{ - &cs, - &cs1, - &cycleStruct{&cs1}, - &cycleStruct{&cs}, - &badCycleStruct{nil}, - &bs, - }, - }, - { - Name: "deadlock", - Objects: []interface{}{ - &bs1, - }, - Fail: true, - }, - { - Name: "embed", - Objects: []interface{}{ - &embed1, - &embed2, - }, - Fail: true, - }, - { - Name: "empty structs", - Objects: []interface{}{ - new(struct{}), - es, - es1, - }, - }, - { - Name: "truncated okay", - Objects: []interface{}{ - &truncateInteger{v: 1}, - &truncateUnsignedInteger{v: 1}, - &truncateFloat{v: 1.0}, - }, - }, - { - Name: "truncated bad", - Objects: []interface{}{ - &truncateInteger{v: math.MaxInt32 + 1}, - &truncateUnsignedInteger{v: math.MaxUint32 + 1}, - &truncateFloat{v: math.MaxFloat32 * 2}, - }, - Fail: true, - }, - } - - runTest(t, tests) -} - -// benchStruct is used for benchmarking. -type benchStruct struct { - b *benchStruct - - // Dummy data is included to ensure that these objects are large. - // This is to detect possible regression when registering objects. - _ [4096]byte -} - -func (b *benchStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *benchStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(b.afterLoad) -} - -func (b *benchStruct) afterLoad() { - // Do nothing, just force scheduling. -} - -// buildObject builds a benchmark object. -func buildObject(n int) (b *benchStruct) { - for i := 0; i < n; i++ { - b = &benchStruct{b: b} - } - return -} - -func BenchmarkEncoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var stats Stats - b.StartTimer() - if err := Save(ioutil.Discard, bs, &stats); err != nil { - b.Errorf("save failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func BenchmarkDecoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var newBS benchStruct - buf := &bytes.Buffer{} - if err := Save(buf, bs, nil); err != nil { - b.Errorf("save failed: %v", err) - } - var stats Stats - b.StartTimer() - if err := Load(buf, &newBS, &stats); err != nil { - b.Errorf("load failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func init() { - Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{ - Save: (*smartStruct).save, - Load: (*smartStruct).load, - }) - Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{ - Save: (*afterLoadStruct).save, - Load: (*afterLoadStruct).load, - }) - Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{ - Save: (*valueLoadStruct).save, - Load: (*valueLoadStruct).load, - }) - Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{ - Save: (*genericContainer).save, - Load: (*genericContainer).load, - }) - Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{ - Save: (*sliceContainer).save, - Load: (*sliceContainer).load, - }) - Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{ - Save: (*mapContainer).save, - Load: (*mapContainer).load, - }) - Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{ - Save: (*pointerStruct).save, - Load: (*pointerStruct).load, - }) - Register("stateTest.testImpl", (*testImpl)(nil), Fns{ - Save: (*testImpl).save, - Load: (*testImpl).load, - }) - Register("stateTest.testI", (*testI)(nil), Fns{ - Save: (*testI).save, - Load: (*testI).load, - }) - Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{ - Save: (*cycleStruct).save, - Load: (*cycleStruct).load, - }) - Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{ - Save: (*badCycleStruct).save, - Load: (*badCycleStruct).load, - }) - Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{ - Save: (*emptyStructPointer).save, - Load: (*emptyStructPointer).load, - }) - Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{ - Save: (*truncateInteger).save, - Load: (*truncateInteger).load, - }) - Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{ - Save: (*truncateUnsignedInteger).save, - Load: (*truncateUnsignedInteger).load, - }) - Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{ - Save: (*truncateFloat).save, - Load: (*truncateFloat).load, - }) - Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{ - Save: (*benchStruct).save, - Load: (*benchStruct).load, - }) -} diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD deleted file mode 100644 index e70f4a79f..000000000 --- a/pkg/state/statefile/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "statefile", - srcs = ["statefile.go"], - importpath = "gvisor.dev/gvisor/pkg/state/statefile", - visibility = ["//:sandbox"], - deps = [ - "//pkg/binary", - "//pkg/compressio", - ], -) - -go_test( - name = "statefile_test", - size = "small", - srcs = ["statefile_test.go"], - embed = [":statefile"], - deps = ["//pkg/compressio"], -) diff --git a/pkg/state/statefile/statefile_state_autogen.go b/pkg/state/statefile/statefile_state_autogen.go new file mode 100755 index 000000000..438c485ca --- /dev/null +++ b/pkg/state/statefile/statefile_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package statefile + diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go deleted file mode 100644 index 0b470fdec..000000000 --- a/pkg/state/statefile/statefile_test.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package statefile - -import ( - "bytes" - crand "crypto/rand" - "encoding/base64" - "io" - "math/rand" - "runtime" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/compressio" -) - -func randomKey() ([]byte, error) { - r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize)) - if _, err := io.ReadFull(crand.Reader, r); err != nil { - return nil, err - } - key := make([]byte, keySize) - base64.RawStdEncoding.Encode(key, r) - return key, nil -} - -type testCase struct { - name string - data []byte - metadata map[string]string -} - -func TestStatefile(t *testing.T) { - rand.Seed(time.Now().Unix()) - - cases := []testCase{ - // Various data sizes. - {"nil", nil, nil}, - {"empty", []byte(""), nil}, - {"some", []byte("_"), nil}, - {"one", []byte("0"), nil}, - {"two", []byte("01"), nil}, - {"three", []byte("012"), nil}, - {"four", []byte("0123"), nil}, - {"five", []byte("01234"), nil}, - {"six", []byte("012356"), nil}, - {"seven", []byte("0123567"), nil}, - {"eight", []byte("01235678"), nil}, - - // Make sure we have one longer than the hash length. - {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil}, - - // Make sure we have one longer than the chunk size. - {"chunks", make([]byte, 3*compressionChunkSize), nil}, - {"large", make([]byte, 30*compressionChunkSize), nil}, - - // Different metadata. - {"one metadata", []byte("data"), map[string]string{"foo": "bar"}}, - {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}}, - } - - for _, c := range cases { - // Generate a key. - integrityKey, err := randomKey() - if err != nil { - t.Errorf("can't generate key: got %v, excepted nil", err) - continue - } - - t.Run(c.name, func(t *testing.T) { - for _, key := range [][]byte{nil, integrityKey} { - t.Run("key="+string(key), func(t *testing.T) { - // Encoding happens via a buffer. - var bufEncoded bytes.Buffer - var bufDecoded bytes.Buffer - - // Do all the writing. - w, err := NewWriter(&bufEncoded, key, c.metadata) - if err != nil { - t.Fatalf("error creating writer: got %v, expected nil", err) - } - if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil { - t.Fatalf("error during write: got %v, expected nil", err) - } - - // Finish the sum. - if err := w.Close(); err != nil { - t.Fatalf("error during close: got %v, expected nil", err) - } - - t.Logf("original data: %d bytes, encoded: %d bytes.", - len(c.data), len(bufEncoded.Bytes())) - - // Do all the reading. - r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key) - if err != nil { - t.Fatalf("error creating reader: got %v, expected nil", err) - } - if _, err := io.Copy(&bufDecoded, r); err != nil { - t.Fatalf("error during read: got %v, expected nil", err) - } - - // Check that the data matches. - if !bytes.Equal(c.data, bufDecoded.Bytes()) { - t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data)) - } - - // Check that the metadata matches. - for k, v := range c.metadata { - nv, ok := metadata[k] - if !ok { - t.Fatalf("missing metadata: %s", k) - } - if v != nv { - t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v) - } - } - - // Change the data and verify that it fails. - if key != nil { - b := append([]byte(nil), bufEncoded.Bytes()...) - b[rand.Intn(len(b))]++ - bufDecoded.Reset() - r, _, err = NewReader(bytes.NewReader(b), key) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err == nil { - t.Error("got no error: expected error on data corruption") - } - } - - // Change the key and verify that it fails. - newKey := integrityKey - if len(key) > 0 { - newKey = append([]byte{}, key...) - newKey[rand.Intn(len(newKey))]++ - } - bufDecoded.Reset() - r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err != compressio.ErrHashMismatch { - t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err) - } - }) - } - }) - } -} - -const benchmarkDataSize = 100 * 1024 * 1024 - -func benchmark(b *testing.B, size int, write bool, compressible bool) { - b.StopTimer() - b.SetBytes(benchmarkDataSize) - - // Generate source data. - var source []byte - if compressible { - // For compressible data, we use essentially all zeros. - source = make([]byte, benchmarkDataSize) - } else { - // For non-compressible data, we use random base64 data (to - // make it marginally compressible, a ratio of 75%). - var sourceBuf bytes.Buffer - bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf) - bufR := rand.New(rand.NewSource(0)) - if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil { - b.Fatalf("unable to seed random data: %v", err) - } - source = sourceBuf.Bytes() - } - - // Generate a random key for integrity check. - key, err := randomKey() - if err != nil { - b.Fatalf("error generating key: %v", err) - } - - // Define our benchmark functions. Prior to running the readState - // function here, you must execute the writeState function at least - // once (done below). - var stateBuf bytes.Buffer - writeState := func() { - stateBuf.Reset() - w, err := NewWriter(&stateBuf, key, nil) - if err != nil { - b.Fatalf("error creating writer: %v", err) - } - for done := 0; done < len(source); { - chunk := size // limit size. - if done+chunk > len(source) { - chunk = len(source) - done - } - n, err := w.Write(source[done : done+chunk]) - done += n - if n == 0 && err != nil { - b.Fatalf("error during write: %v", err) - } - } - if err := w.Close(); err != nil { - b.Fatalf("error closing writer: %v", err) - } - } - readState := func() { - tmpBuf := bytes.NewBuffer(stateBuf.Bytes()) - r, _, err := NewReader(tmpBuf, key) - if err != nil { - b.Fatalf("error creating reader: %v", err) - } - for done := 0; done < len(source); { - chunk := size // limit size. - if done+chunk > len(source) { - chunk = len(source) - done - } - n, err := r.Read(source[done : done+chunk]) - done += n - if n == 0 && err != nil { - b.Fatalf("error during read: %v", err) - } - } - } - // Generate the state once without timing to ensure that buffers have - // been appropriately allocated. - writeState() - if write { - b.StartTimer() - for i := 0; i < b.N; i++ { - writeState() - } - b.StopTimer() - } else { - b.StartTimer() - for i := 0; i < b.N; i++ { - readState() - } - b.StopTimer() - } -} - -func BenchmarkWrite4KCompressible(b *testing.B) { - benchmark(b, 4096, true, true) -} - -func BenchmarkWrite4KNoncompressible(b *testing.B) { - benchmark(b, 4096, true, false) -} - -func BenchmarkWrite1MCompressible(b *testing.B) { - benchmark(b, 1024*1024, true, true) -} - -func BenchmarkWrite1MNoncompressible(b *testing.B) { - benchmark(b, 1024*1024, true, false) -} - -func BenchmarkRead4KCompressible(b *testing.B) { - benchmark(b, 4096, false, true) -} - -func BenchmarkRead4KNoncompressible(b *testing.B) { - benchmark(b, 4096, false, false) -} - -func BenchmarkRead1MCompressible(b *testing.B) { - benchmark(b, 1024*1024, false, true) -} - -func BenchmarkRead1MNoncompressible(b *testing.B) { - benchmark(b, 1024*1024, false, false) -} - -func init() { - runtime.GOMAXPROCS(runtime.NumCPU()) -} diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD deleted file mode 100644 index 5665ad4ee..000000000 --- a/pkg/syserr/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "syserr", - srcs = [ - "host_linux.go", - "netstack.go", - "syserr.go", - ], - importpath = "gvisor.dev/gvisor/pkg/syserr", - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/syserror", - "//pkg/tcpip", - ], -) diff --git a/pkg/syserr/syserr_state_autogen.go b/pkg/syserr/syserr_state_autogen.go new file mode 100755 index 000000000..f34cb096b --- /dev/null +++ b/pkg/syserr/syserr_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package syserr + diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD deleted file mode 100644 index b149f9e02..000000000 --- a/pkg/syserror/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "syserror", - srcs = ["syserror.go"], - importpath = "gvisor.dev/gvisor/pkg/syserror", - visibility = ["//visibility:public"], -) - -go_test( - name = "syserror_test", - srcs = ["syserror_test.go"], - deps = [ - ":syserror", - ], -) diff --git a/pkg/syserror/syserror_state_autogen.go b/pkg/syserror/syserror_state_autogen.go new file mode 100755 index 000000000..5691e4f5e --- /dev/null +++ b/pkg/syserror/syserror_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package syserror + diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go deleted file mode 100644 index 29719752e..000000000 --- a/pkg/syserror/syserror_test.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syserror_test - -import ( - "errors" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/syserror" -) - -var globalError error - -func returnErrnoAsError() error { - return syscall.EINVAL -} - -func returnError() error { - return syserror.EINVAL -} - -func BenchmarkReturnErrnoAsError(b *testing.B) { - for i := b.N; i > 0; i-- { - returnErrnoAsError() - } -} - -func BenchmarkReturnError(b *testing.B) { - for i := b.N; i > 0; i-- { - returnError() - } -} - -func BenchmarkCompareErrno(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - if globalError == syscall.EINVAL { - j++ - } - } -} - -func BenchmarkCompareError(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - if globalError == syserror.EINVAL { - j++ - } - } -} - -func BenchmarkSwitchErrno(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case syscall.EINVAL: - j += 1 - case syscall.EINTR: - j += 2 - case syscall.EAGAIN: - j += 3 - } - } -} - -func BenchmarkSwitchError(b *testing.B) { - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case syserror.EINVAL: - j += 1 - case syserror.EINTR: - j += 2 - case syserror.EAGAIN: - j += 3 - } - } -} - -type translationTestTable struct { - fn string - errIn error - syscallErrorIn syscall.Errno - expectedBool bool - expectedTranslation syscall.Errno -} - -func TestErrorTranslation(t *testing.T) { - myError := errors.New("My test error") - myError2 := errors.New("Another test error") - testTable := []translationTestTable{ - {"TranslateError", myError, 0, false, 0}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError, syscall.EAGAIN, true, 0}, - {"AddErrorTranslation", myError, syscall.EAGAIN, false, 0}, - {"AddErrorTranslation", myError, syscall.EPERM, false, 0}, - {"TranslateError", myError, 0, true, syscall.EAGAIN}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError2, syscall.EPERM, true, 0}, - {"AddErrorTranslation", myError2, syscall.EPERM, false, 0}, - {"AddErrorTranslation", myError2, syscall.EAGAIN, false, 0}, - {"TranslateError", myError, 0, true, syscall.EAGAIN}, - {"TranslateError", myError2, 0, true, syscall.EPERM}, - } - for _, tt := range testTable { - switch tt.fn { - case "TranslateError": - err, ok := syserror.TranslateError(tt.errIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) - } else if err != tt.expectedTranslation { - t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation) - } - case "AddErrorTranslation": - ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) - } - default: - t.Fatalf("Unknown function %v", tt.fn) - } - } -} diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD deleted file mode 100644 index 047f8329a..000000000 --- a/pkg/tcpip/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "tcpip", - srcs = [ - "tcpip.go", - "time_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) - -go_test( - name = "tcpip_test", - size = "small", - srcs = ["tcpip_test.go"], - embed = [":tcpip"], -) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD deleted file mode 100644 index c40924852..000000000 --- a/pkg/tcpip/adapters/gonet/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "gonet", - srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) - -go_test( - name = "gonet_test", - size = "small", - srcs = ["gonet_test.go"], - embed = [":gonet"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@org_golang_x_net//nettest:go_default_library", - ], -) diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go deleted file mode 100644 index 308f620e5..000000000 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ /dev/null @@ -1,704 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gonet provides a Go net package compatible wrapper for a tcpip stack. -package gonet - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - errCanceled = errors.New("operation canceled") - errWouldBlock = errors.New("operation would block") -) - -// timeoutError is how the net package reports timeouts. -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - -// A Listener is a wrapper around a tcpip endpoint that implements -// net.Listener. -type Listener struct { - stack *stack.Stack - ep tcpip.Endpoint - wq *waiter.Queue - cancel chan struct{} -} - -// NewListener creates a new Listener. -func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { - // Create TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - if err := ep.Listen(10); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "listen", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - return &Listener{ - stack: s, - ep: ep, - wq: &wq, - cancel: make(chan struct{}), - }, nil -} - -// Close implements net.Listener.Close. -func (l *Listener) Close() error { - l.ep.Close() - return nil -} - -// Shutdown stops the HTTP server. -func (l *Listener) Shutdown() { - l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - close(l.cancel) // broadcast cancellation -} - -// Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - a, err := l.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -type deadlineTimer struct { - // mu protects the fields below. - mu sync.Mutex - - readTimer *time.Timer - readCancelCh chan struct{} - writeTimer *time.Timer - writeCancelCh chan struct{} -} - -func (d *deadlineTimer) init() { - d.readCancelCh = make(chan struct{}) - d.writeCancelCh = make(chan struct{}) -} - -func (d *deadlineTimer) readCancel() <-chan struct{} { - d.mu.Lock() - c := d.readCancelCh - d.mu.Unlock() - return c -} -func (d *deadlineTimer) writeCancel() <-chan struct{} { - d.mu.Lock() - c := d.writeCancelCh - d.mu.Unlock() - return c -} - -// setDeadline contains the shared logic for setting a deadline. -// -// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and -// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and -// deadlineTimer.writeTimer. -// -// setDeadline must only be called while holding d.mu. -func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { - if *timer != nil && !(*timer).Stop() { - *cancelCh = make(chan struct{}) - } - - // Create a new channel if we already closed it due to setting an already - // expired time. We won't race with the timer because we already handled - // that above. - select { - case <-*cancelCh: - *cancelCh = make(chan struct{}) - default: - } - - // "A zero value for t means I/O operations will not time out." - // - net.Conn.SetDeadline - if t.IsZero() { - return - } - - timeout := t.Sub(time.Now()) - if timeout <= 0 { - close(*cancelCh) - return - } - - // Timer.Stop returns whether or not the AfterFunc has started, but - // does not indicate whether or not it has completed. Make a copy of - // the cancel channel to prevent this code from racing with the next - // call of setDeadline replacing *cancelCh. - ch := *cancelCh - *timer = time.AfterFunc(timeout, func() { - close(ch) - }) -} - -// SetReadDeadline implements net.Conn.SetReadDeadline and -// net.PacketConn.SetReadDeadline. -func (d *deadlineTimer) SetReadDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.readCancelCh, &d.readTimer, t) - d.mu.Unlock() - return nil -} - -// SetWriteDeadline implements net.Conn.SetWriteDeadline and -// net.PacketConn.SetWriteDeadline. -func (d *deadlineTimer) SetWriteDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) - d.mu.Unlock() - return nil -} - -// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline. -func (d *deadlineTimer) SetDeadline(t time.Time) error { - d.mu.Lock() - d.setDeadline(&d.readCancelCh, &d.readTimer, t) - d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) - d.mu.Unlock() - return nil -} - -// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn -// interface. -type Conn struct { - deadlineTimer - - wq *waiter.Queue - ep tcpip.Endpoint - - // readMu serializes reads and implicitly protects read. - // - // Lock ordering: - // If both readMu and deadlineTimer.mu are to be used in a single - // request, readMu must be acquired before deadlineTimer.mu. - readMu sync.Mutex - - // read contains bytes that have been read from the endpoint, - // but haven't yet been returned. - read buffer.View -} - -// NewConn creates a new Conn. -func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { - c := &Conn{ - wq: wq, - ep: ep, - } - c.deadlineTimer.init() - return c -} - -// Accept implements net.Conn.Accept. -func (l *Listener) Accept() (net.Conn, error) { - n, wq, err := l.ep.Accept() - - if err == tcpip.ErrWouldBlock { - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - l.wq.EventRegister(&waitEntry, waiter.EventIn) - defer l.wq.EventUnregister(&waitEntry) - - for { - n, wq, err = l.ep.Accept() - - if err != tcpip.ErrWouldBlock { - break - } - - select { - case <-l.cancel: - return nil, errCanceled - case <-notifyCh: - } - } - } - - if err != nil { - return nil, &net.OpError{ - Op: "accept", - Net: "tcp", - Addr: l.Addr(), - Err: errors.New(err.String()), - } - } - - return NewConn(wq, n), nil -} - -type opErrorer interface { - newOpError(op string, err error) *net.OpError -} - -// commonRead implements the common logic between net.Conn.Read and -// net.PacketConn.ReadFrom. -func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { - select { - case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) - default: - } - - read, _, err := ep.Read(addr) - - if err == tcpip.ErrWouldBlock { - if dontWait { - return nil, errWouldBlock - } - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - for { - read, _, err = ep.Read(addr) - if err != tcpip.ErrWouldBlock { - break - } - select { - case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) - case <-notifyCh: - } - } - } - - if err == tcpip.ErrClosedForReceive { - return nil, io.EOF - } - - if err != nil { - return nil, errorer.newOpError("read", errors.New(err.String())) - } - - return read, nil -} - -// Read implements net.Conn.Read. -func (c *Conn) Read(b []byte) (int, error) { - c.readMu.Lock() - defer c.readMu.Unlock() - - deadline := c.readCancel() - - numRead := 0 - for numRead != len(b) { - if len(c.read) == 0 { - var err error - c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) - if err != nil { - if numRead != 0 { - return numRead, nil - } - return numRead, err - } - } - n := copy(b[numRead:], c.read) - c.read.TrimFront(n) - numRead += n - if len(c.read) == 0 { - c.read = nil - } - } - return numRead, nil -} - -// Write implements net.Conn.Write. -func (c *Conn) Write(b []byte) (int, error) { - deadline := c.writeCancel() - - // Check if deadlineTimer has already expired. - select { - case <-deadline: - return 0, c.newOpError("write", &timeoutError{}) - default: - } - - v := buffer.NewViewFromBytes(b) - - // We must handle two soft failure conditions simultaneously: - // 1. Write may write nothing and return tcpip.ErrWouldBlock. - // If this happens, we need to register for notifications if we have - // not already and wait to try again. - // 2. Write may write fewer than the full number of bytes and return - // without error. In this case we need to try writing the remaining - // bytes again. I do not need to register for notifications. - // - // What is more, these two soft failure conditions can be interspersed. - // There is no guarantee that all of the condition #1s will occur before - // all of the condition #2s or visa-versa. - var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} - ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) - } else { - // Don't wait immediately after registration in case more data - // became available between when we last checked and when we setup - // the notification. - select { - case <-deadline: - return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: - } - } - } - - var n uintptr - var resCh <-chan struct{} - n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - - if resCh != nil { - select { - case <-deadline: - return nbytes, c.newOpError("write", &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - } - - if err == nil { - return nbytes, nil - } - - return nbytes, c.newOpError("write", errors.New(err.String())) -} - -// Close implements net.Conn.Close. -func (c *Conn) Close() error { - c.ep.Close() - return nil -} - -// CloseRead shuts down the reading side of the TCP connection. Most callers -// should just use Close. -// -// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. -func (c *Conn) CloseRead() error { - if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { - return c.newOpError("close", errors.New(terr.String())) - } - return nil -} - -// CloseWrite shuts down the writing side of the TCP connection. Most callers -// should just use Close. -// -// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. -func (c *Conn) CloseWrite() error { - if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { - return c.newOpError("close", errors.New(terr.String())) - } - return nil -} - -// LocalAddr implements net.Conn.LocalAddr. -func (c *Conn) LocalAddr() net.Addr { - a, err := c.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -// RemoteAddr implements net.Conn.RemoteAddr. -func (c *Conn) RemoteAddr() net.Addr { - a, err := c.ep.GetRemoteAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -func (c *Conn) newOpError(op string, err error) *net.OpError { - return &net.OpError{ - Op: op, - Net: "tcp", - Source: c.LocalAddr(), - Addr: c.RemoteAddr(), - Err: err, - } -} - -func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr { - return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} -} - -func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { - return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} -} - -// DialTCP creates a new TCP Conn connected to the specified address. -func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { - return DialContextTCP(context.Background(), s, addr, network) -} - -// DialContextTCP creates a new TCP Conn connected to the specified address -// with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { - // Create TCP endpoint, then connect. - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - // Create wait queue entry that notifies a channel. - // - // We do this unconditionally as Connect will always return an error. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { - select { - case <-ctx.Done(): - ep.Close() - return nil, ctx.Err() - case <-notifyCh: - } - - err = ep.GetSockOpt(tcpip.ErrorOption{}) - } - if err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "connect", - Net: "tcp", - Addr: fullToTCPAddr(addr), - Err: errors.New(err.String()), - } - } - - return NewConn(&wq, ep), nil -} - -// A PacketConn is a wrapper around a tcpip endpoint that implements -// net.PacketConn. -type PacketConn struct { - deadlineTimer - - stack *stack.Stack - ep tcpip.Endpoint - wq *waiter.Queue -} - -// NewPacketConn creates a new PacketConn. -func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { - // Create UDP endpoint and bind it. - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) - if err != nil { - return nil, errors.New(err.String()) - } - - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "udp", - Addr: fullToUDPAddr(addr), - Err: errors.New(err.String()), - } - } - - c := &PacketConn{ - stack: s, - ep: ep, - wq: &wq, - } - c.deadlineTimer.init() - return c, nil -} - -func (c *PacketConn) newOpError(op string, err error) *net.OpError { - return c.newRemoteOpError(op, nil, err) -} - -func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { - return &net.OpError{ - Op: op, - Net: "udp", - Source: c.LocalAddr(), - Addr: remote, - Err: err, - } -} - -// RemoteAddr implements net.Conn.RemoteAddr. -func (c *PacketConn) RemoteAddr() net.Addr { - a, err := c.ep.GetRemoteAddress() - if err != nil { - return nil - } - return fullToTCPAddr(a) -} - -// Read implements net.Conn.Read -func (c *PacketConn) Read(b []byte) (int, error) { - bytesRead, _, err := c.ReadFrom(b) - return bytesRead, err -} - -// ReadFrom implements net.PacketConn.ReadFrom. -func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - deadline := c.readCancel() - - var addr tcpip.FullAddress - read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) - if err != nil { - return 0, nil, err - } - - return copy(b, read), fullToUDPAddr(addr), nil -} - -func (c *PacketConn) Write(b []byte) (int, error) { - return c.WriteTo(b, nil) -} - -// WriteTo implements net.PacketConn.WriteTo. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - deadline := c.writeCancel() - - // Check if deadline has already expired. - select { - case <-deadline: - return 0, c.newRemoteOpError("write", addr, &timeoutError{}) - default: - } - - // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} - if addr != nil { - ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} - } - - v := buffer.NewView(len(b)) - copy(v, b) - - n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts) - if resCh != nil { - select { - case <-deadline: - return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - } - - if err == tcpip.ErrWouldBlock { - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) - for { - select { - case <-deadline: - return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) - case <-notifyCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - if err != tcpip.ErrWouldBlock { - break - } - } - } - - if err == nil { - return int(n), nil - } - - return int(n), c.newRemoteOpError("write", addr, errors.New(err.String())) -} - -// Close implements net.PacketConn.Close. -func (c *PacketConn) Close() error { - c.ep.Close() - return nil -} - -// LocalAddr implements net.PacketConn.LocalAddr. -func (c *PacketConn) LocalAddr() net.Addr { - a, err := c.ep.GetLocalAddress() - if err != nil { - return nil - } - return fullToUDPAddr(a) -} diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go deleted file mode 100644 index 39efe44c7..000000000 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ /dev/null @@ -1,643 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gonet - -import ( - "context" - "fmt" - "io" - "net" - "reflect" - "strings" - "testing" - "time" - - "golang.org/x/net/nettest" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - NICID = 1 -) - -func TestTimeouts(t *testing.T) { - nc := NewConn(nil, nil) - dlfs := []struct { - name string - f func(time.Time) error - }{ - {"SetDeadline", nc.SetDeadline}, - {"SetReadDeadline", nc.SetReadDeadline}, - {"SetWriteDeadline", nc.SetWriteDeadline}, - } - - for _, dlf := range dlfs { - if err := dlf.f(time.Time{}); err != nil { - t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil) - } - } -} - -func newLoopbackStack() (*stack.Stack, *tcpip.Error) { - // Create the stack and add a NIC. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) - - if err := s.CreateNIC(NICID, loopback.New()); err != nil { - return nil, err - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - // IPv4 - { - Destination: tcpip.Address(strings.Repeat("\x00", 4)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 4)), - Gateway: "", - NIC: NICID, - }, - - // IPv6 - { - Destination: tcpip.Address(strings.Repeat("\x00", 16)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 16)), - Gateway: "", - NIC: NICID, - }, - }) - - return s, nil -} - -type testConnection struct { - wq *waiter.Queue - e *waiter.Entry - ch chan struct{} - ep tcpip.Endpoint -} - -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { - wq := &waiter.Queue{} - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - - entry, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&entry, waiter.EventOut) - - err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { - <-ch - err = ep.GetSockOpt(tcpip.ErrorOption{}) - } - if err != nil { - return nil, err - } - - wq.EventUnregister(&entry) - wq.EventRegister(&entry, waiter.EventIn) - - return &testConnection{wq, &entry, ch, ep}, nil -} - -func (c *testConnection) close() { - c.wq.EventUnregister(c.e) - c.ep.Close() -} - -// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock. -func TestCloseReader(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := NewListener(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Fatalf("l.Accept() = %v", err) - } - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when -// using tcp.Forwarder. -func TestCloseReaderWithForwarder(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - done := make(chan struct{}) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.Close() - }) - - buf := make([]byte, 256) - n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) - } - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(5 * time.Second): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestCloseRead(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewConn(tc.wq, tc.ep) - - if err := c.CloseRead(); err != nil { - t.Errorf("c.CloseRead() = %v", err) - } - - buf := make([]byte, 256) - if n, err := c.Read(buf); err != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err) - } - - if n, err := c.Write([]byte("abc123")); n != 6 || err != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err) - } -} - -func TestCloseWrite(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - r.Complete(false) - - c := NewConn(&wq, ep) - - n, e := c.Read(make([]byte, 256)) - if n != 0 || e != io.EOF { - t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e) - } - - if n, e = c.Write([]byte("abc123")); n != 6 || e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } - }) - - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - tc, terr := connect(s, addr) - if terr != nil { - t.Fatalf("connect() = %v", terr) - } - c := NewConn(tc.wq, tc.ep) - - if err := c.CloseWrite(); err != nil { - t.Errorf("c.CloseWrite() = %v", err) - } - - buf := make([]byte, 256) - n, err := c.Read(buf) - if err != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err) - } - - n, err = c.Write([]byte("abc123")) - got, ok := err.(*net.OpError) - want := "endpoint is closed for send" - if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) { - t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want) - } -} - -func TestUDPForwarder(t *testing.T) { - s, terr := newLoopbackStack() - if terr != nil { - t.Fatalf("newLoopbackStack() = %v", terr) - } - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - done := make(chan struct{}) - fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) { - defer close(done) - - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - t.Fatalf("r.CreateEndpoint() = %v", err) - } - defer ep.Close() - - c := NewConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil { - t.Errorf("c.Read() = %v", e) - } - - if _, e := c.Write(buf[:n]); e != nil { - t.Errorf("c.Write() = %v", e) - } - }) - s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("NewPacketConn(port 5):", err) - } - - sent := "abc123" - sendAddr := fullToUDPAddr(addr1) - if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - - buf := make([]byte, 256) - n, recvAddr, err := c2.ReadFrom(buf) - if err != nil || recvAddr.String() != sendAddr.String() { - t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil) - } -} - -// TestDeadlineChange tests that changing the deadline affects currently blocked reads. -func TestDeadlineChange(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - l, e := NewListener(s, addr, ipv4.ProtocolNumber) - if e != nil { - t.Fatalf("NewListener() = %v", e) - } - done := make(chan struct{}) - go func() { - defer close(done) - c, err := l.Accept() - if err != nil { - t.Fatalf("l.Accept() = %v", err) - } - - c.SetDeadline(time.Now().Add(time.Minute)) - // Give c.Read() a chance to block before closing the connection. - time.AfterFunc(time.Millisecond*50, func() { - c.SetDeadline(time.Now().Add(time.Millisecond * 10)) - }) - - buf := make([]byte, 256) - n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := "i/o timeout" - if n != 0 || !ok || got.Err == nil || got.Err.Error() != want { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want) - } - }() - sender, err := connect(s, addr) - if err != nil { - t.Fatalf("connect() = %v", err) - } - - select { - case <-done: - case <-time.After(time.Millisecond * 500): - t.Errorf("c.Read() didn't unblock") - } - sender.close() -} - -func TestPacketConnTransfer(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - - ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr1 := tcpip.FullAddress{NICID, ip1, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip1) - ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4()) - addr2 := tcpip.FullAddress{NICID, ip2, 11311} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - - c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("NewPacketConn(port 4):", err) - } - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) - if err != nil { - t.Fatal("NewPacketConn(port 5):", err) - } - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - sent := "abc123" - sendAddr := fullToUDPAddr(addr2) - if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) { - t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil) - } - recv := make([]byte, len(sent)) - n, recvAddr, err := c2.ReadFrom(recv) - if err != nil || n != len(recv) { - t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil) - } - - if recv := string(recv); recv != sent { - t.Errorf("got recv = %q, want = %q", recv, sent) - } - - if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) { - t.Errorf("got recvAddr = %v, want = %v", recvAddr, want) - } - - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } -} - -func makePipe() (c1, c2 net.Conn, stop func(), err error) { - s, e := newLoopbackStack() - if e != nil { - return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - - l, err := NewListener(s, addr, ipv4.ProtocolNumber) - if err != nil { - return nil, nil, nil, fmt.Errorf("NewListener: %v", err) - } - - c1, err = DialTCP(s, addr, ipv4.ProtocolNumber) - if err != nil { - l.Close() - return nil, nil, nil, fmt.Errorf("DialTCP: %v", err) - } - - c2, err = l.Accept() - if err != nil { - l.Close() - c1.Close() - return nil, nil, nil, fmt.Errorf("l.Accept: %v", err) - } - - stop = func() { - c1.Close() - c2.Close() - } - - if err := l.Close(); err != nil { - stop() - return nil, nil, nil, fmt.Errorf("l.Close(): %v", err) - } - - return c1, c2, stop, nil -} - -func TestTCPConnTransfer(t *testing.T) { - c1, c2, _, err := makePipe() - if err != nil { - t.Fatal(err) - } - defer func() { - if err := c1.Close(); err != nil { - t.Error("c1.Close():", err) - } - if err := c2.Close(); err != nil { - t.Error("c2.Close():", err) - } - }() - - c1.SetDeadline(time.Now().Add(time.Second)) - c2.SetDeadline(time.Now().Add(time.Second)) - - const sent = "abc123" - - tests := []struct { - name string - c1 net.Conn - c2 net.Conn - }{ - {"connected to accepted", c1, c2}, - {"accepted to connected", c2, c1}, - } - - for _, test := range tests { - if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) { - t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil) - continue - } - - recv := make([]byte, len(sent)) - n, err := test.c2.Read(recv) - if err != nil || n != len(recv) { - t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil) - continue - } - - if recv := string(recv); recv != sent { - t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent) - } - } -} - -func TestTCPDialError(t *testing.T) { - s, e := newLoopbackStack() - if e != nil { - t.Fatalf("newLoopbackStack() = %v", e) - } - - ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) - addr := tcpip.FullAddress{NICID, ip, 11211} - - _, err := DialTCP(s, addr, ipv4.ProtocolNumber) - got, ok := err.(*net.OpError) - want := tcpip.ErrNoRoute - if !ok || got.Err.Error() != want.String() { - t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) - } -} - -func TestDialContextTCPCanceled(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled) - } -} - -func TestDialContextTCPTimeout(t *testing.T) { - s, err := newLoopbackStack() - if err != nil { - t.Fatalf("newLoopbackStack() = %v", err) - } - - addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} - s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - - fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { - time.Sleep(time.Second) - r.Complete(true) - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) - - ctx := context.Background() - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) - defer cancel() - - if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded { - t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded) - } -} - -func TestNetTest(t *testing.T) { - nettest.TestConn(t, makePipe) -} diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD deleted file mode 100644 index 3301967fb..000000000 --- a/pkg/tcpip/buffer/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "buffer", - srcs = [ - "prependable.go", - "view.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer", - visibility = ["//visibility:public"], -) - -go_test( - name = "buffer_test", - size = "small", - srcs = ["view_test.go"], - embed = [":buffer"], -) diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go new file mode 100755 index 000000000..41a74ea97 --- /dev/null +++ b/pkg/tcpip/buffer/buffer_state_autogen.go @@ -0,0 +1,24 @@ +// automatically generated by stateify. + +package buffer + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *VectorisedView) beforeSave() {} +func (x *VectorisedView) save(m state.Map) { + x.beforeSave() + m.Save("views", &x.views) + m.Save("size", &x.size) +} + +func (x *VectorisedView) afterLoad() {} +func (x *VectorisedView) load(m state.Map) { + m.Load("views", &x.views) + m.Load("size", &x.size) +} + +func init() { + state.Register("buffer.VectorisedView", (*VectorisedView)(nil), state.Fns{Save: (*VectorisedView).save, Load: (*VectorisedView).load}) +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go deleted file mode 100644 index ebc3a17b7..000000000 --- a/pkg/tcpip/buffer/view_test.go +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package buffer_test contains tests for the VectorisedView type. -package buffer - -import ( - "reflect" - "testing" -) - -// copy returns a deep-copy of the vectorised view. -func (vv VectorisedView) copy() VectorisedView { - uu := VectorisedView{ - views: make([]View, 0, len(vv.views)), - size: vv.size, - } - for _, v := range vv.views { - uu.views = append(uu.views, append(View(nil), v...)) - } - return uu -} - -// vv is an helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) VectorisedView { - views := make([]View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return NewVectorisedView(size, views) -} - -var capLengthTestCases = []struct { - comment string - in VectorisedView - length int - want VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Case spanning across two Views", - in: vv(4, "123", "4"), - length: 2, - want: vv(2, "12"), - }, - { - comment: "Corner case with negative length", - in: vv(1, "1"), - length: -1, - want: vv(0), - }, - { - comment: "Corner case with length = 0", - in: vv(3, "12", "3"), - length: 0, - want: vv(0), - }, - { - comment: "Corner case with length = size", - in: vv(1, "1"), - length: 1, - want: vv(1, "1"), - }, - { - comment: "Corner case with length > size", - in: vv(1, "1"), - length: 2, - want: vv(1, "1"), - }, -} - -func TestCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - orig := c.in.copy() - c.in.CapLength(c.length) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v", - c.comment, c.length, orig, c.in, c.want) - } - } -} - -var trimFrontTestCases = []struct { - comment string - in VectorisedView - count int - want VectorisedView -}{ - { - comment: "Simple case", - in: vv(2, "12"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case where we trim an entire View", - in: vv(2, "1", "2"), - count: 1, - want: vv(1, "2"), - }, - { - comment: "Case spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: vv(1, "3"), - }, - { - comment: "Corner case with negative count", - in: vv(1, "1"), - count: -1, - want: vv(1, "1"), - }, - { - comment: " Corner case with count = 0", - in: vv(1, "1"), - count: 0, - want: vv(1, "1"), - }, - { - comment: "Corner case with count = size", - in: vv(1, "1"), - count: 1, - want: vv(0), - }, - { - comment: "Corner case with count > size", - in: vv(1, "1"), - count: 2, - want: vv(0), - }, -} - -func TestTrimFront(t *testing.T) { - for _, c := range trimFrontTestCases { - orig := c.in.copy() - c.in.TrimFront(c.count) - if !reflect.DeepEqual(c.in, c.want) { - t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v", - c.comment, c.count, orig, c.in, c.want) - } - } -} - -var toViewCases = []struct { - comment string - in VectorisedView - want View -}{ - { - comment: "Simple case", - in: vv(2, "12"), - want: []byte("12"), - }, - { - comment: "Case with multiple views", - in: vv(2, "1", "2"), - want: []byte("12"), - }, - { - comment: "Empty case", - in: vv(0), - want: []byte(""), - }, -} - -func TestToView(t *testing.T) { - for _, c := range toViewCases { - got := c.in.ToView() - if !reflect.DeepEqual(got, c.want) { - t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v", - c.comment, c.in, got, c.want) - } - } -} - -var toCloneCases = []struct { - comment string - inView VectorisedView - inBuffer []View -}{ - { - comment: "Simple case", - inView: vv(1, "1"), - inBuffer: make([]View, 1), - }, - { - comment: "Case with multiple views", - inView: vv(2, "1", "2"), - inBuffer: make([]View, 2), - }, - { - comment: "Case with buffer too small", - inView: vv(2, "1", "2"), - inBuffer: make([]View, 1), - }, - { - comment: "Case with buffer larger than needed", - inView: vv(1, "1"), - inBuffer: make([]View, 2), - }, - { - comment: "Case with nil buffer", - inView: vv(1, "1"), - inBuffer: nil, - }, -} - -func TestToClone(t *testing.T) { - for _, c := range toCloneCases { - t.Run(c.comment, func(t *testing.T) { - got := c.inView.Clone(c.inBuffer) - if !reflect.DeepEqual(got, c.inView) { - t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v", - c.inView, c.inBuffer, got, c.inView) - } - }) - } -} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD deleted file mode 100644 index 4cecfb989..000000000 --- a/pkg/tcpip/checker/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "checker", - testonly = 1, - srcs = ["checker.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/checker", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go deleted file mode 100644 index afcabd51d..000000000 --- a/pkg/tcpip/checker/checker.go +++ /dev/null @@ -1,588 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package checker provides helper functions to check networking packets for -// validity. -package checker - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// NetworkChecker is a function to check a property of a network packet. -type NetworkChecker func(*testing.T, []header.Network) - -// TransportChecker is a function to check a property of a transport packet. -type TransportChecker func(*testing.T, header.Transport) - -// IPv4 checks the validity and properties of the given IPv4 packet. It is -// expected to be used in conjunction with other network checkers for specific -// properties. For example, to check the source and destination address, one -// would call: -// -// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) -func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv4 := header.IPv4(b) - - if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") - } - - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) - } - - for _, f := range checkers { - f(t, []header.Network{ipv4}) - } - if t.Failed() { - t.FailNow() - } -} - -// IPv6 checks the validity and properties of the given IPv6 packet. The usage -// is similar to IPv4. -func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") - } - - for _, f := range checkers { - f(t, []header.Network{ipv6}) - } - if t.Failed() { - t.FailNow() - } -} - -// SrcAddr creates a checker that checks the source address. -func SrcAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].SourceAddress(); a != addr { - t.Errorf("Bad source address, got %v, want %v", a, addr) - } - } -} - -// DstAddr creates a checker that checks the destination address. -func DstAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].DestinationAddress(); a != addr { - t.Errorf("Bad destination address, got %v, want %v", a, addr) - } - } -} - -// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). -func TTL(ttl uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - var v uint8 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TTL() - case header.IPv6: - v = ip.HopLimit() - } - if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) - } - } -} - -// PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) - } - } -} - -// FragmentOffset creates a checker that checks the FragmentOffset field. -func FragmentOffset(offset uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this of IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) - } - } - } -} - -// FragmentFlags creates a checker that checks the fragment flags field. -func FragmentFlags(flags uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this of IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) - } - } - } -} - -// TOS creates a checker that checks the TOS field. -func TOS(tos uint8, label uint32) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) - } - } -} - -// Raw creates a checker that checks the bytes of payload. -// The checker always checks the payload of the last network header. -// For instance, in case of IPv6 fragments, the payload that will be checked -// is the one containing the actual data that the packet is carrying, without -// the bytes added by the IPv6 fragmentation. -func Raw(want []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// IPv6Fragment creates a checker that validates an IPv6 fragment. -func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) - } - - ipv6Frag := header.IPv6Fragment(h[0].Payload()) - if !ipv6Frag.IsValid() { - t.Error("Not a valid IPv6 fragment") - } - - for _, f := range checkers { - f(t, []header.Network{h[0], ipv6Frag}) - } - if t.Failed() { - t.FailNow() - } - } -} - -// TCP creates a checker that checks that the transport protocol is TCP and -// potentially additional transport header fields. -func TCP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - first := h[0] - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) - } - - // Verify the checksum. - tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) - } - - // Run the transport checkers. - for _, f := range checkers { - f(t, tcp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// UDP creates a checker that checks that the transport protocol is UDP and -// potentially additional transport header fields. -func UDP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) - } - - udp := header.UDP(last.Payload()) - for _, f := range checkers { - f(t, udp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// SrcPort creates a checker that checks the source port. -func SrcPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) - } - } -} - -// DstPort creates a checker that checks the destination port. -func DstPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) - } - } -} - -// SeqNum creates a checker that checks the sequence number. -func SeqNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) - } - } -} - -// AckNum creates a checker that checks the ack number. -func AckNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) - } - } -} - -// Window creates a checker that checks the tcp window. -func Window(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) - } - } -} - -// TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) - } - } -} - -// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the -// given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) - } - } -} - -// TCPSynOptions creates a checker that checks the presence of TCP options in -// SYN segments. -// -// If wndscale is negative, the window scale option must not be present. -func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := tcp.Options() - limit := len(opts) - foundMSS := false - foundWS := false - foundTS := false - foundSACKPermitted := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionMSS: - v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) - if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) - } - foundMSS = true - i += 4 - case header.TCPOptionWS: - if wantOpts.WS < 0 { - t.Error("WS present when it shouldn't be") - } - v := int(opts[i+2]) - if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) - } - foundWS = true - i += 3 - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = uint32(0) - if tcp.Flags()&header.TCPFlagAck != 0 { - // If the syn is an SYN-ACK then read - // the tsEcr value as well. - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - } - foundTS = true - i += 10 - case header.TCPOptionSACKPermitted: - if i+1 >= limit { - t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) - } - if opts[i+1] != 2 { - t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) - } - foundSACKPermitted = true - i += 2 - - default: - i += int(opts[i+1]) - } - } - - if !foundMSS { - t.Errorf("MSS option not found. Options: %x", opts) - } - - if !foundWS && wantOpts.WS >= 0 { - t.Errorf("WS option not found. Options: %x", opts) - } - if wantOpts.TS && !foundTS { - t.Errorf("TS option not found. Options: %x", opts) - } - if foundTS && tsVal == 0 { - t.Error("TS option specified but the timestamp value is zero") - } - if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) - } - if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Errorf("SACKPermitted option not found. Options: %x", opts) - } - } -} - -// TCPTimestampChecker creates a checker that validates that a TCP segment has a -// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and -// wantTSEcr values with those in the TCP segment (if present). -// -// If wantTSVal or wantTSEcr is zero then the corresponding comparison is -// skipped. -func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := []byte(tcp.Options()) - limit := len(opts) - foundTS := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - foundTS = true - i += 10 - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - return - } - l := int(opts[i+1]) - if i < 2 || i+l > limit { - return - } - i += l - } - } - - if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) - } - if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) - } - if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) - } - } -} - -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not -// contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - -// TCPSACKBlockChecker creates a checker that verifies that the segment does -// contain the specified SACK blocks in the TCP options. -func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - var gotSACKBlocks []header.SACKBlock - - opts := []byte(tcp.Options()) - limit := len(opts) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionSACK: - if i+2 > limit { - // Malformed SACK block. - t.Errorf("malformed SACK option in options: %v", opts) - } - sackOptionLen := int(opts[i+1]) - if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { - // Malformed SACK block. - t.Errorf("malformed SACK option length in options: %v", opts) - } - numBlocks := sackOptionLen / 8 - for j := 0; j < numBlocks; j++ { - start := binary.BigEndian.Uint32(opts[i+2+j*8:]) - end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) - gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ - Start: seqnum.Value(start), - End: seqnum.Value(end), - }) - } - i += sackOptionLen - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - break - } - l := int(opts[i+1]) - if l < 2 || i+l > limit { - break - } - i += l - } - } - - if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) - } - } -} - -// Payload creates a checker that checks the payload. -func Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD deleted file mode 100644 index 29b30be9c..000000000 --- a/pkg/tcpip/hash/jenkins/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "jenkins", - srcs = ["jenkins.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", - visibility = [ - "//visibility:public", - ], -) - -go_test( - name = "jenkins_test", - size = "small", - srcs = [ - "jenkins_test.go", - ], - embed = [":jenkins"], -) diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go new file mode 100755 index 000000000..310f0ee6d --- /dev/null +++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package jenkins + diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go deleted file mode 100644 index 4c78b5808..000000000 --- a/pkg/tcpip/hash/jenkins/jenkins_test.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package jenkins - -import ( - "bytes" - "encoding/binary" - "hash" - "hash/fnv" - "math" - "testing" -) - -func TestGolden32(t *testing.T) { - var golden32 = []struct { - out []byte - in string - }{ - {[]byte{0x00, 0x00, 0x00, 0x00}, ""}, - {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"}, - {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"}, - {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"}, - } - - hash := New32() - - for _, g := range golden32 { - hash.Reset() - done, error := hash.Write([]byte(g.in)) - if error != nil { - t.Fatalf("write error: %s", error) - } - if done != len(g.in) { - t.Fatalf("wrote only %d out of %d bytes", done, len(g.in)) - } - if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) { - t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out) - } - } -} - -func TestIntegrity32(t *testing.T) { - data := []byte{'1', '2', 3, 4, 5} - - h := New32() - h.Write(data) - sum := h.Sum(nil) - - if size := h.Size(); size != len(sum) { - t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum)) - } - - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a) - } - - h.Reset() - h.Write(data[:2]) - h.Write(data[2:]) - if a := h.Sum(nil); !bytes.Equal(sum, a) { - t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a) - } - - sum32 := h.(hash.Hash32).Sum32() - if sum32 != binary.BigEndian.Uint32(sum) { - t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32) - } -} - -func BenchmarkJenkins32KB(b *testing.B) { - h := New32() - - b.SetBytes(1024) - data := make([]byte, 1024) - for i := range data { - data[i] = byte(i) - } - in := make([]byte, 0, h.Size()) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - h.Reset() - h.Write(data) - h.Sum(in) - } -} - -func BenchmarkFnv32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - - h := fnv.New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - c := 0 - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - if c == 0 { - b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N) - } - c++ - } - } - if c > 0 { - b.Logf("Unbalanced buckets: %d", c) - } - } -} - -func BenchmarkSum32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := Sum32(0) - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} - -func BenchmarkNew32(b *testing.B) { - arr := make([]int64, 1000) - for i := 0; i < b.N; i++ { - var payload [8]byte - binary.BigEndian.PutUint32(payload[:4], uint32(i)) - binary.BigEndian.PutUint32(payload[4:], uint32(i)) - h := New32() - h.Write(payload[:]) - idx := int(h.Sum32()) % len(arr) - arr[idx]++ - } - b.StopTimer() - if b.N > 1000000 { - for i := 0; i < len(arr)-1; i++ { - if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) { - b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N) - break - } - } - } -} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD deleted file mode 100644 index 76ef02f13..000000000 --- a/pkg/tcpip/header/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "header", - srcs = [ - "arp.go", - "checksum.go", - "eth.go", - "gue.go", - "icmpv4.go", - "icmpv6.go", - "interfaces.go", - "ipv4.go", - "ipv6.go", - "ipv6_fragment.go", - "tcp.go", - "udp.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/seqnum", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "header_test", - size = "small", - srcs = [ - "ipversion_test.go", - "tcp_test.go", - ], - deps = [ - ":header", - ], -) diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go new file mode 100755 index 000000000..4fcb89452 --- /dev/null +++ b/pkg/tcpip/header/header_state_autogen.go @@ -0,0 +1,42 @@ +// automatically generated by stateify. + +package header + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *SACKBlock) beforeSave() {} +func (x *SACKBlock) save(m state.Map) { + x.beforeSave() + m.Save("Start", &x.Start) + m.Save("End", &x.End) +} + +func (x *SACKBlock) afterLoad() {} +func (x *SACKBlock) load(m state.Map) { + m.Load("Start", &x.Start) + m.Load("End", &x.End) +} + +func (x *TCPOptions) beforeSave() {} +func (x *TCPOptions) save(m state.Map) { + x.beforeSave() + m.Save("TS", &x.TS) + m.Save("TSVal", &x.TSVal) + m.Save("TSEcr", &x.TSEcr) + m.Save("SACKBlocks", &x.SACKBlocks) +} + +func (x *TCPOptions) afterLoad() {} +func (x *TCPOptions) load(m state.Map) { + m.Load("TS", &x.TS) + m.Load("TSVal", &x.TSVal) + m.Load("TSEcr", &x.TSEcr) + m.Load("SACKBlocks", &x.SACKBlocks) +} + +func init() { + state.Register("header.SACKBlock", (*SACKBlock)(nil), state.Fns{Save: (*SACKBlock).save, Load: (*SACKBlock).load}) + state.Register("header.TCPOptions", (*TCPOptions)(nil), state.Fns{Save: (*TCPOptions).save, Load: (*TCPOptions).load}) +} diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go deleted file mode 100644 index b5540bf66..000000000 --- a/pkg/tcpip/header/ipversion_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestIPv4(t *testing.T) { - b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) - - const want = header.IPv4Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestIPv6(t *testing.T) { - b := header.IPv6(make([]byte, header.IPv6MinimumSize)) - b.Encode(&header.IPv6Fields{}) - - const want = header.IPv6Version - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestOtherVersion(t *testing.T) { - const want = header.IPv4Version + header.IPv6Version - b := make([]byte, 1) - b[0] = want << 4 - - if v := header.IPVersion(b); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} - -func TestTooShort(t *testing.T) { - b := make([]byte, 1) - b[0] = (header.IPv4Version + header.IPv6Version) << 4 - - // Get the version of a zero-length slice. - const want = -1 - if v := header.IPVersion(b[:0]); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } - - // Get the version of a nil slice. - if v := header.IPVersion(nil); v != want { - t.Fatalf("Bad version, want %v, got %v", want, v) - } -} diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go deleted file mode 100644 index 72563837b..000000000 --- a/pkg/tcpip/header/tcp_test.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package header_test - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestEncodeSACKBlocks(t *testing.T) { - testCases := []struct { - sackBlocks []header.SACKBlock - want []header.SACKBlock - bufSize int - }{ - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 40, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, - 30, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}}, - 20, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}}, - 10, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - nil, - 8, - }, - { - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, - []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, - 60, - }, - } - for _, tc := range testCases { - b := make([]byte, tc.bufSize) - t.Logf("testing: %v", tc) - header.EncodeSACKBlocks(tc.sackBlocks, b) - opts := header.ParseTCPOptions(b) - if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want) - } - } -} - -func TestTCPParseOptions(t *testing.T) { - type tsOption struct { - tsVal uint32 - tsEcr uint32 - } - - generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte { - l := 0 - if tsOpt != nil { - l += 10 - } - if len(sackBlocks) != 0 { - l += len(sackBlocks)*8 + 2 - } - b := make([]byte, l) - offset := 0 - if tsOpt != nil { - offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b) - } - header.EncodeSACKBlocks(sackBlocks, b[offset:]) - return b - } - - testCases := []struct { - b []byte - want header.TCPOptions - }{ - // Trivial cases. - {nil, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test timestamp parsing. - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - - // Test malformed timestamp option. - {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - - // Test SACKBlock parsing. - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}}, - {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}}, - - // Test malformed SACK option. - {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}}, - - // Test Timestamp + SACK block parsing. - {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}}, - {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}}, - - // Test valid timestamp + malformed SACK block parsing. - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, - {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}}, - {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}}, - {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, - } - for _, tc := range testCases { - if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) { - t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want) - } - } -} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD deleted file mode 100644 index 97a794986..000000000 --- a/pkg/tcpip/link/channel/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "channel", - srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go deleted file mode 100644 index c40744b8e..000000000 --- a/pkg/tcpip/link/channel/channel.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package channel provides the implemention of channel-based data-link layer -// endpoints. Such endpoints allow injection of inbound packets and store -// outbound packets in a channel. -package channel - -import ( - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// PacketInfo holds all the information about an outbound packet. -type PacketInfo struct { - Header buffer.View - Payload buffer.View - Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO -} - -// Endpoint is link layer endpoint that stores outbound packets in a channel -// and allows injection of inbound packets. -type Endpoint struct { - dispatcher stack.NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - GSO bool - - // C is where outbound packets are queued. - C chan PacketInfo -} - -// New creates a new channel endpoint. -func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - C: make(chan PacketInfo, size), - mtu: mtu, - linkAddr: linkAddr, - } - - return stack.RegisterLinkEndpoint(e), e -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *Endpoint) Drain() int { - c := 0 - for { - select { - case <-e.C: - c++ - default: - return c - } - } -} - -// Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.InjectLinkAddr(protocol, "", vv) -} - -// InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil)) -} - -// Attach saves the stack network-layer dispatcher for use later when packets -// are injected. -func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *Endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *Endpoint) MTU() uint32 { - return e.mtu -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - caps := stack.LinkEndpointCapabilities(0) - if e.GSO { - caps |= stack.CapabilityGSO - } - return caps -} - -// GSOMaxSize returns the maximum GSO packet size. -func (*Endpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - -// MaxHeaderLength returns the maximum size of the link layer header. Given it -// doesn't have a header, it just returns 0. -func (*Endpoint) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (e *Endpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -// WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: payload.ToView(), - GSO: gso, - } - - select { - case e.C <- p: - default: - } - - return nil -} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD deleted file mode 100644 index d786d8fdf..000000000 --- a/pkg/tcpip/link/fdbased/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "fdbased", - srcs = [ - "endpoint.go", - "endpoint_unsafe.go", - "mmap.go", - "mmap_amd64.go", - "mmap_amd64_unsafe.go", - "packet_dispatchers.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -go_test( - name = "fdbased_test", - size = "small", - srcs = ["endpoint_test.go"], - embed = [":fdbased"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go deleted file mode 100644 index e305252d6..000000000 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ /dev/null @@ -1,455 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package fdbased - -import ( - "bytes" - "fmt" - "math/rand" - "reflect" - "syscall" - "testing" - "time" - "unsafe" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - mtu = 1500 - laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") - raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") - proto = 10 - csumOffset = 48 - gsoMSS = 500 -) - -type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents buffer.View -} - -type context struct { - t *testing.T - fds [2]int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} -} - -func newContext(t *testing.T, opt *Options) *context { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) - if err != nil { - t.Fatalf("Socketpair failed: %v", err) - } - - done := make(chan struct{}, 1) - opt.ClosedFunc = func(*tcpip.Error) { - done <- struct{}{} - } - - opt.FDs = []int{fds[1]} - epID, err := New(opt) - if err != nil { - t.Fatalf("Failed to create FD endpoint: %v", err) - } - ep := stack.FindLinkEndpoint(epID).(*endpoint) - - c := &context{ - t: t, - fds: fds, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, - } - - ep.Attach(c) - - return c -} - -func (c *context) cleanup() { - syscall.Close(c.fds[0]) - <-c.done - syscall.Close(c.fds[1]) -} - -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - c.ch <- packetInfo{remote, protocol, vv.ToView()} -} - -func TestNoEthernetProperties(t *testing.T) { - c := newContext(t, &Options{MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestEthernetProperties(t *testing.T) { - c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) - defer c.cleanup() - - if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { - t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) - } - - if want, v := uint32(mtu), c.ep.MTU(); want != v { - t.Fatalf("MTU() = %v, want %v", v, want) - } -} - -func TestAddress(t *testing.T) { - addrs := []tcpip.LinkAddress{"", "abc", "def"} - for _, a := range addrs { - t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { - c := newContext(t, &Options{Address: a, MTU: mtu}) - defer c.cleanup() - - if want, v := a, c.ep.LinkAddress(); want != v { - t.Fatalf("LinkAddress() = %v, want %v", v, want) - } - }) - } -} - -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) - defer c.cleanup() - - r := &stack.Route{ - RemoteLinkAddress: raddr, - } - - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) - } - - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) - } - want := append(hdr.View(), payload...) - var gso *stack.GSO - if gsoMaxSize != 0 { - gso = &stack.GSO{ - Type: stack.GSOTCPv6, - NeedsCsum: true, - CsumOffset: csumOffset, - MSS: gsoMSS, - MaxSize: gsoMaxSize, - L3HdrLen: header.IPv4MaximumHeaderSize, - } - } - if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from fd, then compare with what we wrote. - b = make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - if gsoMaxSize != 0 { - vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) - if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { - t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) - } - csumStart := header.EthernetMinimumSize + gso.L3HdrLen - if vnetHdr.csumStart != csumStart { - t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) - } - if vnetHdr.csumOffset != csumOffset { - t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) - } - gsoType := uint8(0) - if int(gso.MSS) < plen { - gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - } - if vnetHdr.gsoType != gsoType { - t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) - } - b = b[virtioNetHdrSize:] - } - if eth { - h := header.Ethernet(b) - b = b[header.EthernetMinimumSize:] - - if a := h.SourceAddress(); a != laddr { - t.Fatalf("SourceAddress() = %v, want %v", a, laddr) - } - - if a := h.DestinationAddress(); a != raddr { - t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) - } - - if et := h.Type(); et != proto { - t.Fatalf("Type() = %v, want %v", et, proto) - } - } - if len(b) != len(want) { - t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) - } - if !bytes.Equal(b, want) { - t.Fatalf("Read returned %x, want %x", b, want) - } -} - -func TestWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} - gsos := []uint32{0, 32768} - - for _, eth := range eths { - for _, plen := range lengths { - for _, gso := range gsos { - t.Run( - fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), - func(t *testing.T) { - testWritePacket(t, plen, eth, gso) - }, - ) - } - } - } -} - -func TestPreserveSrcAddress(t *testing.T) { - baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") - - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true}) - defer c.cleanup() - - // Set LocalLinkAddress in route to the value of the bridged address. - r := &stack.Route{ - RemoteLinkAddress: raddr, - LocalLinkAddress: baddr, - } - - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Read from the FD, then compare with what we wrote. - b := make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - h := header.Ethernet(b) - - if a := h.SourceAddress(); a != baddr { - t.Fatalf("SourceAddress() = %v, want %v", a, baddr) - } -} - -func TestDeliverPacket(t *testing.T) { - lengths := []int{100, 1000} - eths := []bool{true, false} - - for _, eth := range eths { - for _, plen := range lengths { - t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) - defer c.cleanup() - - // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) - } - - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr := make(header.Ethernet, header.EthernetMinimumSize) - hdr.Encode(&header.EthernetFields{ - SrcAddr: raddr, - DstAddr: laddr, - Type: proto, - }) - all = append(hdr, b...) - } - - // Write packet via the file descriptor. - if _, err := syscall.Write(c.fds[0], all); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Receive packet through the endpoint. - select { - case pi := <-c.ch: - want := packetInfo{ - raddr: raddr, - proto: proto, - contents: b, - } - if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) - } - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for packet") - } - }) - } - } -} - -func TestBufConfigMaxLength(t *testing.T) { - got := 0 - for _, i := range BufConfig { - got += i - } - want := header.MaxIPPacketSize // maximum TCP packet size - if got < want { - t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) - } -} - -func TestBufConfigFirst(t *testing.T) { - // The stack assumes that the TCP/IP header is enterily contained in the first view. - // Therefore, the first view needs to be large enough to contain the maximum TCP/IP - // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). - want := 120 - got := BufConfig[0] - if got < want { - t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) - } -} - -var capLengthTestCases = []struct { - comment string - config []int - n int - wantUsed int - wantLengths []int -}{ - { - comment: "Single slice", - config: []int{2}, - n: 1, - wantUsed: 1, - wantLengths: []int{1}, - }, - { - comment: "Multiple slices", - config: []int{1, 2}, - n: 2, - wantUsed: 2, - wantLengths: []int{1, 1}, - }, - { - comment: "Entire buffer", - config: []int{1, 2}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2}, - }, - { - comment: "Entire buffer but not on the last slice", - config: []int{1, 2, 3}, - n: 3, - wantUsed: 2, - wantLengths: []int{1, 2, 3}, - }, -} - -func TestReadVDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - // fd does not matter for this test. - d := readVDispatcher{fd: -1, e: &endpoint{}} - d.views = make([]buffer.View, len(c.config)) - d.iovecs = make([]syscall.Iovec, len(c.config)) - d.allocateViews(c.config) - - used := d.capViews(c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views)) - for i, v := range d.views { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - } -} - -func TestRecvMMsgDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - d := recvMMsgDispatcher{ - fd: -1, // fd does not matter for this test. - e: &endpoint{}, - views: make([][]buffer.View, 1), - iovecs: make([][]syscall.Iovec, 1), - msgHdrs: make([]rawfile.MMsgHdr, 1), - } - - for i, _ := range d.views { - d.views[i] = make([]buffer.View, len(c.config)) - } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, len(c.config)) - } - for k, msgHdr := range d.msgHdrs { - msgHdr.Msg.Iov = &d.iovecs[k][0] - msgHdr.Msg.Iovlen = uint64(len(c.config)) - } - - d.allocateViews(c.config) - - used := d.capViews(0, c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views[0])) - for i, v := range d.views[0] { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - - } -} diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go new file mode 100755 index 000000000..0555db528 --- /dev/null +++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package fdbased + diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD deleted file mode 100644 index 47a54845c..000000000 --- a/pkg/tcpip/link/loopback/BUILD +++ /dev/null @@ -1,15 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "loopback", - srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go new file mode 100755 index 000000000..87ec8cfc7 --- /dev/null +++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package loopback + diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD deleted file mode 100644 index ea12ef1ac..000000000 --- a/pkg/tcpip/link/muxed/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "muxed", - srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "muxed_test", - size = "small", - srcs = ["injectable_test.go"], - embed = [":muxed"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go deleted file mode 100644 index a577a3d52..000000000 --- a/pkg/tcpip/link/muxed/injectable.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package muxed provides a muxed link endpoints. -package muxed - -import ( - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// InjectableEndpoint is an injectable multi endpoint. The endpoint has -// trivial routing rules that determine which InjectableEndpoint a given packet -// will be written to. Note that HandleLocal works differently for this -// endpoint (see WritePacket). -type InjectableEndpoint struct { - routes map[tcpip.Address]stack.InjectableLinkEndpoint - dispatcher stack.NetworkDispatcher -} - -// MTU implements stack.LinkEndpoint. -func (m *InjectableEndpoint) MTU() uint32 { - minMTU := ^uint32(0) - for _, endpoint := range m.routes { - if endpointMTU := endpoint.MTU(); endpointMTU < minMTU { - minMTU = endpointMTU - } - } - return minMTU -} - -// Capabilities implements stack.LinkEndpoint. -func (m *InjectableEndpoint) Capabilities() stack.LinkEndpointCapabilities { - minCapabilities := stack.LinkEndpointCapabilities(^uint(0)) - for _, endpoint := range m.routes { - minCapabilities &= endpoint.Capabilities() - } - return minCapabilities -} - -// MaxHeaderLength implements stack.LinkEndpoint. -func (m *InjectableEndpoint) MaxHeaderLength() uint16 { - minHeaderLen := ^uint16(0) - for _, endpoint := range m.routes { - if headerLen := endpoint.MaxHeaderLength(); headerLen < minHeaderLen { - minHeaderLen = headerLen - } - } - return minHeaderLen -} - -// LinkAddress implements stack.LinkEndpoint. -func (m *InjectableEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -// Attach implements stack.LinkEndpoint. -func (m *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - for _, endpoint := range m.routes { - endpoint.Attach(dispatcher) - } - m.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint. -func (m *InjectableEndpoint) IsAttached() bool { - return m.dispatcher != nil -} - -// Inject implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv) -} - -// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint -// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a -// route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol) - } - return tcpip.ErrNoRoute -} - -// WriteRawPacket writes outbound packets to the appropriate -// LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { - endpoint, ok := m.routes[dest] - if !ok { - return tcpip.ErrNoRoute - } - return endpoint.WriteRawPacket(dest, packet) -} - -// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) { - e := &InjectableEndpoint{ - routes: routes, - } - return stack.RegisterLinkEndpoint(e), e -} diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go deleted file mode 100644 index 174b9330f..000000000 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package muxed - -import ( - "bytes" - "net" - "os" - "syscall" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -func TestInjectableEndpointRawDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - endpoint.WriteRawPacket(dstIP, []byte{0xFA}) - - buf := make([]byte, ipv4.MaxTotalSize) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatch(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) - - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { - endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) - buf := make([]byte, 6500) - bytesRead, err := sock.Read(buf) - if err != nil { - t.Fatalf("Unable to read from socketpair: %v", err) - } - if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) { - t.Fatalf("Read %v from the socketpair, wanted %v", got, want) - } -} - -func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) { - dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4()) - pair, err := syscall.Socketpair(syscall.AF_UNIX, - syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0) - if err != nil { - t.Fatal("Failed to create socket pair:", err) - } - _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) - routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - _, endpoint := NewInjectableEndpoint(routes) - return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP -} diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD deleted file mode 100644 index 6e3a7a9d7..000000000 --- a/pkg/tcpip/link/rawfile/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "rawfile", - srcs = [ - "blockingpoll_amd64.s", - "blockingpoll_amd64_unsafe.go", - "blockingpoll_unsafe.go", - "errors.go", - "rawfile_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", - visibility = [ - "//visibility:public", - ], - deps = ["//pkg/tcpip"], -) diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go new file mode 100755 index 000000000..662c04444 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package rawfile + diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD deleted file mode 100644 index f2998aa98..000000000 --- a/pkg/tcpip/link/sharedmem/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "sharedmem", - srcs = [ - "rx.go", - "sharedmem.go", - "sharedmem_unsafe.go", - "tx.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "sharedmem_test", - srcs = [ - "sharedmem_test.go", - ], - embed = [":sharedmem"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/sharedmem/pipe", - "//pkg/tcpip/link/sharedmem/queue", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD deleted file mode 100644 index 94725cb11..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "pipe", - srcs = [ - "pipe.go", - "pipe_unsafe.go", - "rx.go", - "tx.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", - visibility = ["//:sandbox"], -) - -go_test( - name = "pipe_test", - srcs = [ - "pipe_test.go", - ], - embed = [":pipe"], -) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go deleted file mode 100644 index 74c9f0311..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package pipe implements a shared memory ring buffer on which a single reader -// and a single writer can operate (read/write) concurrently. The ring buffer -// allows for data of different sizes to be written, and preserves the boundary -// of the written data. -// -// Example usage is as follows: -// -// wb := t.Push(20) -// // Write data to wb. -// t.Flush() -// -// rb := r.Pull() -// // Do something with data in rb. -// t.Flush() -package pipe - -import ( - "math" -) - -const ( - jump uint64 = math.MaxUint32 + 1 - offsetMask uint64 = math.MaxUint32 - revolutionMask uint64 = ^offsetMask - - sizeOfSlotHeader = 8 // sizeof(uint64) - slotFree uint64 = 1 << 63 - slotSizeMask uint64 = math.MaxUint32 -) - -// payloadToSlotSize calculates the total size of a slot based on its payload -// size. The total size is the header size, plus the payload size, plus padding -// if necessary to make the total size a multiple of sizeOfSlotHeader. -func payloadToSlotSize(payloadSize uint64) uint64 { - s := sizeOfSlotHeader + payloadSize - return (s + sizeOfSlotHeader - 1) &^ (sizeOfSlotHeader - 1) -} - -// slotToPayloadSize calculates the payload size of a slot based on the total -// size of the slot. This is only meant to be used when creating slots that -// don't carry information (e.g., free slots or wrap slots). -func slotToPayloadSize(offset uint64) uint64 { - return offset - sizeOfSlotHeader -} - -// pipe is a basic data structure used by both (transmit & receive) ends of a -// pipe. Indices into this pipe are split into two fields: offset, which counts -// the number of bytes from the beginning of the buffer, and revolution, which -// counts the number of times the index has wrapped around. -type pipe struct { - buffer []byte -} - -// init initializes the pipe buffer such that its size is a multiple of the size -// of the slot header. -func (p *pipe) init(b []byte) { - p.buffer = b[:len(b)&^(sizeOfSlotHeader-1)] -} - -// data returns a section of the buffer starting at the given index (which may -// include revolution information) and with the given size. -func (p *pipe) data(idx uint64, size uint64) []byte { - return p.buffer[(idx&offsetMask)+sizeOfSlotHeader:][:size] -} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go deleted file mode 100644 index 59ef69a8b..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "math/rand" - "reflect" - "runtime" - "sync" - "testing" -) - -func TestSimpleReadWrite(t *testing.T) { - // Check that a simple write can be properly read from the rx side. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - wb := tx.Push(10) - if wb == nil { - t.Fatalf("Push failed on empty pipe") - } - for i := range wb { - wb[i] = byte(tr.Intn(256)) - } - tx.Flush() - - var rx Rx - rx.Init(b) - rb := rx.Pull() - if len(rb) != 10 { - t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10) - } - - for i := range rb { - if v := byte(rr.Intn(256)); v != rb[i] { - t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v) - } - } - rx.Flush() -} - -func TestEmptyRead(t *testing.T) { - // Check that pulling from an empty pipe fails. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestTooLargeWrite(t *testing.T) { - // Check that writes that are too large are properly rejected. - b := make([]byte, 96) - var tx Tx - tx.Init(b) - - if wb := tx.Push(96); wb != nil { - t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(88); wb != nil { - t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe") - } - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } -} - -func TestFullWrite(t *testing.T) { - // Check that writes fail when the pipe is full. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestFullAndFlushedWrite(t *testing.T) { - // Check that writes fail when the pipe is full and has already been - // flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(80); wb == nil { - t.Fatalf("Write of 80 bytes failed on 96-byte pipe") - } - - tx.Flush() - - if wb := tx.Push(1); wb != nil { - t.Fatalf("Write succeeded on full pipe") - } -} - -func TestTxFlushTwice(t *testing.T) { - // Checks that a second consecutive tx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // Make copy of original tx queue, flush it, then check that it didn't - // change. - orig := tx - tx.Flush() - - if !reflect.DeepEqual(orig, tx) { - t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig) - } -} - -func TestRxFlushTwice(t *testing.T) { - // Checks that a second consecutive rx flush is a no-op. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // Make copy of original rx queue, flush it, then check that it didn't - // change. - orig := rx - rx.Flush() - - if !reflect.DeepEqual(orig, rx) { - t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig) - } -} - -func TestWrapInMiddleOfTransaction(t *testing.T) { - // Check that writes are not flushed when we need to wrap the buffer - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestWriteAbort(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it but - // has aborted the push. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Abort() - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestWrappedWriteAbort(t *testing.T) { - // Check that writes are properly aborted even if the writes wrap - // around. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - // We haven't flushed yet, so pull must return nil. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - tx.Abort() - - // The pushes were aborted, so no data should be readable. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on non-flushed pipe") - } - - // Try the same transactions again, but flush this time. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The two buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } -} - -func TestEmptyReadOnNonFlushedWrite(t *testing.T) { - // Check that a read fails on a pipe that has had data pushed to it - // but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(10); wb == nil { - t.Fatalf("Write failed on empty pipe") - } - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } - - tx.Flush() - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull on failed on non-empty pipe") - } -} - -func TestPullAfterPullingEntirePipe(t *testing.T) { - // Check that Pull fails when the pipe is full, but all of it has - // already been pulled but not yet flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3 - // buffers that will fill the pipe. - if wb := tx.Push(10); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - if wb := tx.Push(24); wb == nil { - t.Fatalf("Push failed on non-full pipe") - } - - tx.Flush() - - // The three buffers must be available now. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - - // Fourth pull must fail. - if rb := rx.Pull(); rb != nil { - t.Fatalf("Pull succeeded on empty pipe") - } -} - -func TestNoRoomToWrapOnPush(t *testing.T) { - // Check that Push fails when it tries to allocate room to add a wrap - // message. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - var rx Rx - rx.Init(b) - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // At this point the ring buffer is empty, but the write is at offset - // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20, - // which won't fit (64+20+8+padding = 96, which wouldn't leave room for - // the padding), so it wraps around. - if wb := tx.Push(20); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - - tx.Flush() - - // Buffer offset is at 28. Try to write 70, which would require a wrap - // slot which cannot be created now. - if wb := tx.Push(70); wb != nil { - t.Fatalf("Push succeeded on pipe with no room for wrap message") - } -} - -func TestRxImplicitFlushOfWrapMessage(t *testing.T) { - // Check if the first read is that of a wrapping message, that it gets - // immediately flushed. - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - if wb := tx.Push(50); wb == nil { - t.Fatalf("Push failed on empty pipe") - } - tx.Flush() - - // This will cause a wrapping message to written. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - var rx Rx - rx.Init(b) - - // Read the first message. - if rb := rx.Pull(); rb == nil { - t.Fatalf("Pull failed on non-empty pipe") - } - rx.Flush() - - // This should fail because of the wrapping message is taking up space. - if wb := tx.Push(60); wb != nil { - t.Fatalf("Push succeeded when there is no room in pipe") - } - - // Try to read the next one. This should consume the wrapping message. - rx.Pull() - - // This must now succeed. - if wb := tx.Push(60); wb == nil { - t.Fatalf("Push failed on empty pipe") - } -} - -func TestConcurrentReaderWriter(t *testing.T) { - // Push a million buffers of random sizes and random contents. Check - // that buffers read match what was written. - tr := rand.New(rand.NewSource(99)) - rr := rand.New(rand.NewSource(99)) - - b := make([]byte, 100) - var tx Tx - tx.Init(b) - - var rx Rx - rx.Init(b) - - const count = 1000000 - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + tr.Intn(80) - wb := tx.Push(uint64(n)) - for wb == nil { - wb = tx.Push(uint64(n)) - } - - for j := range wb { - wb[j] = byte(tr.Intn(256)) - } - - tx.Flush() - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } - - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } - - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } - } - - rx.Flush() - } - }() - - wg.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go deleted file mode 100644 index 62d17029e..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "sync/atomic" - "unsafe" -) - -func (p *pipe) write(idx uint64, v uint64) { - ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) - *ptr = v -} - -func (p *pipe) writeAtomic(idx uint64, v uint64) { - ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) - atomic.StoreUint64(ptr, v) -} - -func (p *pipe) readAtomic(idx uint64) uint64 { - ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) - return atomic.LoadUint64(ptr) -} diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go deleted file mode 100644 index f22e533ac..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/rx.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -// Rx is the receive side of the shared memory ring buffer. -type Rx struct { - p pipe - - tail uint64 - head uint64 -} - -// Init initializes the receive end of the pipe. In the initial state, the next -// slot to be inspected is the very first one. -func (r *Rx) Init(b []byte) { - r.p.init(b) - r.tail = 0xfffffffe * jump - r.head = r.tail -} - -// Pull reads the next buffer from the pipe, returning nil if there isn't one -// currently available. -// -// The returned slice is available until Flush() is next called. After that, it -// must not be touched. -func (r *Rx) Pull() []byte { - if r.head == r.tail+jump { - // We've already pulled the whole pipe. - return nil - } - - header := r.p.readAtomic(r.head) - if header&slotFree != 0 { - // The next slot is free, we can't pull it yet. - return nil - } - - payloadSize := header & slotSizeMask - newHead := r.head + payloadToSlotSize(payloadSize) - headWrap := (r.head & revolutionMask) | uint64(len(r.p.buffer)) - - // Check if this is a wrapping slot. If that's the case, it carries no - // data, so we just skip it and try again from the first slot. - if int64(newHead-headWrap) >= 0 { - if int64(newHead-headWrap) > int64(jump) || newHead&offsetMask != 0 { - return nil - } - - if r.tail == r.head { - // If this is the first pull since the last Flush() - // call, we flush the state so that the sender can use - // this space if it needs to. - r.p.writeAtomic(r.head, slotFree|slotToPayloadSize(newHead-r.head)) - r.tail = newHead - } - - r.head = newHead - return r.Pull() - } - - // Grab the buffer before updating r.head. - b := r.p.data(r.head, payloadSize) - r.head = newHead - return b -} - -// Flush tells the transmitter that all buffers pulled since the last Flush() -// have been used, so the transmitter is free to used their slots for further -// transmission. -func (r *Rx) Flush() { - if r.head == r.tail { - return - } - r.p.writeAtomic(r.tail, slotFree|slotToPayloadSize(r.head-r.tail)) - r.tail = r.head -} - -// Bytes returns the byte slice on which the pipe operates. -func (r *Rx) Bytes() []byte { - return r.p.buffer -} diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go deleted file mode 100644 index 9841eb231..000000000 --- a/pkg/tcpip/link/sharedmem/pipe/tx.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -// Tx is the transmit side of the shared memory ring buffer. -type Tx struct { - p pipe - maxPayloadSize uint64 - - head uint64 - tail uint64 - next uint64 - - tailHeader uint64 -} - -// Init initializes the transmit end of the pipe. In the initial state, the next -// slot to be written is the very first one, and the transmitter has the whole -// ring buffer available to it. -func (t *Tx) Init(b []byte) { - t.p.init(b) - // maxPayloadSize excludes the header of the payload, and the header - // of the wrapping message. - t.maxPayloadSize = uint64(len(t.p.buffer)) - 2*sizeOfSlotHeader - t.tail = 0xfffffffe * jump - t.next = t.tail - t.head = t.tail + jump - t.p.write(t.tail, slotFree) -} - -// Capacity determines how many records of the given size can be written to the -// pipe before it fills up. -func (t *Tx) Capacity(recordSize uint64) uint64 { - available := uint64(len(t.p.buffer)) - sizeOfSlotHeader - entryLen := payloadToSlotSize(recordSize) - return available / entryLen -} - -// Push reserves "payloadSize" bytes for transmission in the pipe. The caller -// populates the returned slice with the data to be transferred and enventually -// calls Flush() to make the data visible to the reader, or Abort() to make the -// pipe forget all Push() calls since the last Flush(). -// -// The returned slice is available until Flush() or Abort() is next called. -// After that, it must not be touched. -func (t *Tx) Push(payloadSize uint64) []byte { - // Fail request if we know we will never have enough room. - if payloadSize > t.maxPayloadSize { - return nil - } - - totalLen := payloadToSlotSize(payloadSize) - newNext := t.next + totalLen - nextWrap := (t.next & revolutionMask) | uint64(len(t.p.buffer)) - if int64(newNext-nextWrap) >= 0 { - // The new buffer would overflow the pipe, so we push a wrapping - // slot, then try to add the actual slot to the front of the - // pipe. - newNext = (newNext & revolutionMask) + jump - wrappingPayloadSize := slotToPayloadSize(newNext - t.next) - if !t.reclaim(newNext) { - return nil - } - - oldNext := t.next - t.next = newNext - if oldNext != t.tail { - t.p.write(oldNext, wrappingPayloadSize) - } else { - t.tailHeader = wrappingPayloadSize - t.Flush() - } - - newNext += totalLen - } - - // Check that we have enough room for the buffer. - if !t.reclaim(newNext) { - return nil - } - - if t.next != t.tail { - t.p.write(t.next, payloadSize) - } else { - t.tailHeader = payloadSize - } - - // Grab the buffer before updating t.next. - b := t.p.data(t.next, payloadSize) - t.next = newNext - - return b -} - -// reclaim attempts to advance the head until at least newNext. If the head is -// already at or beyond newNext, nothing happens and true is returned; otherwise -// it tries to reclaim slots that have already been consumed by the receive end -// of the pipe (they will be marked as free) and returns a boolean indicating -// whether it was successful in reclaiming enough slots. -func (t *Tx) reclaim(newNext uint64) bool { - for int64(newNext-t.head) > 0 { - // Can't reclaim if slot is not free. - header := t.p.readAtomic(t.head) - if header&slotFree == 0 { - return false - } - - payloadSize := header & slotSizeMask - newHead := t.head + payloadToSlotSize(payloadSize) - - // Check newHead is within bounds and valid. - if int64(newHead-t.tail) > int64(jump) || newHead&offsetMask >= uint64(len(t.p.buffer)) { - return false - } - - t.head = newHead - } - - return true -} - -// Abort causes all Push() calls since the last Flush() to be forgotten and -// therefore they will not be made visible to the receiver. -func (t *Tx) Abort() { - t.next = t.tail -} - -// Flush causes all buffers pushed since the last Flush() [or Abort(), whichever -// is the most recent] to be made visible to the receiver. -func (t *Tx) Flush() { - if t.next == t.tail { - // Nothing to do if there are no pushed buffers. - return - } - - if t.next != t.head { - // The receiver will spin in t.next, so we must make sure that - // the slotFree bit is set. - t.p.write(t.next, slotFree) - } - - t.p.writeAtomic(t.tail, t.tailHeader) - t.tail = t.next -} - -// Bytes returns the byte slice on which the pipe operates. -func (t *Tx) Bytes() []byte { - return t.p.buffer -} diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD deleted file mode 100644 index 160a8f864..000000000 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "queue", - srcs = [ - "rx.go", - "tx.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/tcpip/link/sharedmem/pipe", - ], -) - -go_test( - name = "queue_test", - srcs = [ - "queue_test.go", - ], - embed = [":queue"], - deps = [ - "//pkg/tcpip/link/sharedmem/pipe", - ], -) diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go deleted file mode 100644 index 9a0aad5d7..000000000 --- a/pkg/tcpip/link/sharedmem/queue/queue_test.go +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "encoding/binary" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -func TestBasicTxQueue(t *testing.T) { - // Tests that a basic transmit on a queue works, and that completion - // gets properly reported as well. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Enqueue two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on empty queue") - } - - // Check the contents of the pipe. - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - - want := []byte{ - 234, 3, 0, 0, 0, 0, 0, 0, // id - 100, 0, 0, 0, // total size - 0, 0, 0, 0, // reserved - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - } - - if !reflect.DeepEqual(want, d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want) - } - - rxp.Flush() - - // Check that there are no completions yet. - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Packet reported as completed too soon") - } - - // Post a completion. - d = txp.Push(8) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - - // Check that completion is properly reported. - id, ok := q.CompletedPacket() - if !ok { - t.Fatalf("Completion not reported") - } - - if id != usedID { - t.Fatalf("Bad completion id: got %v, want %v", id, usedID) - } -} - -func TestBasicRxQueue(t *testing.T) { - // Tests that a basic receive on a queue works. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on empty queue") - } - - // Check the contents of the pipe. - want := [][]byte{ - { - 100, 0, 0, 0, 0, 0, 0, 0, // Offset1 - 60, 0, 0, 0, // Size1 - 0, 0, 0, 0, // Remaining in group 1 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - }, - { - 200, 0, 0, 0, 0, 0, 0, 0, // Offset2 - 40, 0, 0, 0, // Size2 - 0, 0, 0, 0, // Remaining in group 2 - 0, 0, 0, 0, 0, 0, 0, 0, // User data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }, - } - - for i := range b { - d := rxp.Pull() - if d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - - if !reflect.DeepEqual(want[i], d) { - t.Fatalf("Bad posted packet: got %v, want %v", d, want[i]) - } - - rxp.Flush() - } - - // Check that there are no completions. - if _, n := q.Dequeue(nil); n != 0 { - t.Fatalf("Packet reported as received too soon") - } - - // Post a completion. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - // Check that completion is properly reported. - bufs, n := q.Dequeue(nil) - if n != 100 { - t.Fatalf("Bad packet size: got %v, want %v", n, 100) - } - - if !reflect.DeepEqual(bufs, b) { - t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b) - } -} - -func TestBadTxCompletion(t *testing.T) { - // Check that tx completions with bad sizes are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion that is too long, and check that it is ignored. - if d := txp.Push(10); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if _, ok := q.CompletedPacket(); ok { - t.Fatalf("Bad completion not ignored") - } -} - -func TestBadRxCompletion(t *testing.T) { - // Check that bad rx completions are properly ignored. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a completion that is too short, and check that it is ignored. - if d := txp.Push(7); d == nil { - t.Fatalf("Unable to push to rx pipe") - } - txp.Flush() - - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes add up to less than the total - // size. - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 10, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 10, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } - - // Post a completion whose buffer sizes will cause a 32-bit overflow, - // but adds up to the right number. - d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 255, 255, 255, 255, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 101, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - if b, _ := q.Dequeue(nil); b != nil { - t.Fatalf("Bad completion not ignored") - } -} - -func TestFillTxPipe(t *testing.T) { - // Check that transmitting a new buffer when the buffer pipe is full - // fails gracefully. - pb1 := make([]byte, 104) - pb2 := make([]byte, 104) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Transmit twice, which should fill the tx pipe. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - for i := uint64(0); i < 2; i++ { - if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) { - t.Fatalf("Failed to transmit buffer") - } - } - - // Transmit another packet now that the tx pipe is full. - if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue succeeded when tx pipe is full") - } -} - -func TestFillRxPipe(t *testing.T) { - // Check that posting a new buffer when the buffer pipe is full fails - // gracefully. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Post a buffer twice, it should fill the tx pipe. - b := []RxBuffer{ - {100, 60, 1077, 0}, - } - - for i := 0; i < 2; i++ { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - } - - // Post another buffer now that the tx pipe is full. - if q.PostBuffers(b) { - t.Fatalf("PostBuffers succeeded on full queue") - } -} - -func TestLotsOfTransmissions(t *testing.T) { - // Make sure pipes are being properly flushed when transmitting packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Tx - q.Init(pb1, pb2) - - // Prepare packet with two buffers. - b := []TxBuffer{ - {nil, 100, 60}, - {nil, 200, 40}, - } - - b[0].Next = &b[1] - - const usedID = 1002 - const usedTotalSize = 100 - - // Post 100000 packets and completions. - for i := 100000; i > 0; i-- { - if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { - t.Fatalf("Enqueue failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after Enqueue") - } - rxp.Flush() - - d := txp.Push(8) - if d == nil { - t.Fatalf("Unable to write to rx pipe") - } - binary.LittleEndian.PutUint64(d, usedID) - txp.Flush() - if _, ok := q.CompletedPacket(); !ok { - t.Fatalf("Completion not returned") - } - } -} - -func TestLotsOfReceptions(t *testing.T) { - // Make sure pipes are being properly flushed when receiving packets. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var rxp pipe.Rx - rxp.Init(pb1) - - var txp pipe.Tx - txp.Init(pb2) - - var q Rx - q.Init(pb1, pb2, nil) - - // Prepare for posting two buffers. - b := []RxBuffer{ - {100, 60, 1077, 0}, - {200, 40, 2123, 0}, - } - - // Post 100000 buffers and completions. - for i := 100000; i > 0; i-- { - if !q.PostBuffers(b) { - t.Fatalf("PostBuffers failed on non-full queue") - } - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - if d := rxp.Pull(); d == nil { - t.Fatalf("Tx pipe is empty after PostBuffers") - } - rxp.Flush() - - d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) - if d == nil { - t.Fatalf("Unable to push to rx pipe") - } - - copy(d, []byte{ - 100, 0, 0, 0, // packet size - 0, 0, 0, 0, // reserved - - 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 - 60, 0, 0, 0, // size 1 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 - 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 - - 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 - 40, 0, 0, 0, // size 2 - 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 - 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 - }) - - txp.Flush() - - if _, n := q.Dequeue(nil); n == 0 { - t.Fatalf("Dequeue failed when there is a completion") - } - } -} - -func TestRxEnableNotification(t *testing.T) { - // Check that enabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.EnableNotification() - if state != eventFDEnabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled) - } -} - -func TestRxDisableNotification(t *testing.T) { - // Check that disabling nofifications results in properly updated state. - pb1 := make([]byte, 100) - pb2 := make([]byte, 100) - - var state uint32 - var q Rx - q.Init(pb1, pb2, &state) - - q.DisableNotification() - if state != eventFDDisabled { - t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled) - } -} diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go deleted file mode 100644 index 696e6c9e5..000000000 --- a/pkg/tcpip/link/sharedmem/queue/rx.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package queue provides the implementation of transmit and receive queues -// based on shared memory ring buffers. -package queue - -import ( - "encoding/binary" - "sync/atomic" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -const ( - // Offsets within a posted buffer. - postedOffset = 0 - postedSize = 8 - postedRemainingInGroup = 12 - postedUserData = 16 - postedID = 24 - - sizeOfPostedBuffer = 32 - - // Offsets within a received packet header. - consumedPacketSize = 0 - consumedPacketReserved = 4 - - sizeOfConsumedPacketHeader = 8 - - // Offsets within a consumed buffer. - consumedOffset = 0 - consumedSize = 8 - consumedUserData = 12 - consumedID = 20 - - sizeOfConsumedBuffer = 28 - - // The following are the allowed states of the shared data area. - eventFDUninitialized = 0 - eventFDDisabled = 1 - eventFDEnabled = 2 -) - -// RxBuffer is the descriptor of a receive buffer. -type RxBuffer struct { - Offset uint64 - Size uint32 - ID uint64 - UserData uint64 -} - -// Rx is a receive queue. It is implemented with one tx and one rx pipe: the tx -// pipe is used to "post" buffers, while the rx pipe is used to receive packets -// whose contents have been written to previously posted buffers. -// -// This struct is thread-compatible. -type Rx struct { - tx pipe.Tx - rx pipe.Rx - sharedEventFDState *uint32 -} - -// Init initializes the receive queue with the given pipes, and shared state -// pointer -- the latter is used to enable/disable eventfd notifications. -func (r *Rx) Init(tx, rx []byte, sharedEventFDState *uint32) { - r.sharedEventFDState = sharedEventFDState - r.tx.Init(tx) - r.rx.Init(rx) -} - -// EnableNotification updates the shared state such that the peer will notify -// the eventfd when there are packets to be dequeued. -func (r *Rx) EnableNotification() { - atomic.StoreUint32(r.sharedEventFDState, eventFDEnabled) -} - -// DisableNotification updates the shared state such that the peer will not -// notify the eventfd. -func (r *Rx) DisableNotification() { - atomic.StoreUint32(r.sharedEventFDState, eventFDDisabled) -} - -// PostedBuffersLimit returns the maximum number of buffers that can be posted -// before the tx queue fills up. -func (r *Rx) PostedBuffersLimit() uint64 { - return r.tx.Capacity(sizeOfPostedBuffer) -} - -// PostBuffers makes the given buffers available for receiving data from the -// peer. Once they are posted, the peer is free to write to them and will -// eventually post them back for consumption. -func (r *Rx) PostBuffers(buffers []RxBuffer) bool { - for i := range buffers { - b := r.tx.Push(sizeOfPostedBuffer) - if b == nil { - r.tx.Abort() - return false - } - - pb := &buffers[i] - binary.LittleEndian.PutUint64(b[postedOffset:], pb.Offset) - binary.LittleEndian.PutUint32(b[postedSize:], pb.Size) - binary.LittleEndian.PutUint32(b[postedRemainingInGroup:], 0) - binary.LittleEndian.PutUint64(b[postedUserData:], pb.UserData) - binary.LittleEndian.PutUint64(b[postedID:], pb.ID) - } - - r.tx.Flush() - - return true -} - -// Dequeue receives buffers that have been previously posted by PostBuffers() -// and that have been filled by the peer and posted back. -// -// This is similar to append() in that new buffers are appended to "bufs", with -// reallocation only if "bufs" doesn't have enough capacity. -func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) { - for { - outBufs := bufs - - // Pull the next descriptor from the rx pipe. - b := r.rx.Pull() - if b == nil { - return bufs, 0 - } - - if len(b) < sizeOfConsumedPacketHeader { - log.Warningf("Ignoring packet header: size (%v) is less than header size (%v)", len(b), sizeOfConsumedPacketHeader) - r.rx.Flush() - continue - } - - totalDataSize := binary.LittleEndian.Uint32(b[consumedPacketSize:]) - - // Calculate the number of buffer descriptors and copy them - // over to the output. - count := (len(b) - sizeOfConsumedPacketHeader) / sizeOfConsumedBuffer - offset := sizeOfConsumedPacketHeader - buffersSize := uint32(0) - for i := count; i > 0; i-- { - s := binary.LittleEndian.Uint32(b[offset+consumedSize:]) - buffersSize += s - if buffersSize < s { - // The buffer size overflows an unsigned 32-bit - // integer, so break out and force it to be - // ignored. - totalDataSize = 1 - buffersSize = 0 - break - } - - outBufs = append(outBufs, RxBuffer{ - Offset: binary.LittleEndian.Uint64(b[offset+consumedOffset:]), - Size: s, - ID: binary.LittleEndian.Uint64(b[offset+consumedID:]), - }) - - offset += sizeOfConsumedBuffer - } - - r.rx.Flush() - - if buffersSize < totalDataSize { - // The descriptor is corrupted, ignore it. - log.Warningf("Ignoring packet: actual data size (%v) less than expected size (%v)", buffersSize, totalDataSize) - continue - } - - return outBufs, totalDataSize - } -} - -// Bytes returns the byte slices on which the queue operates. -func (r *Rx) Bytes() (tx, rx []byte) { - return r.tx.Bytes(), r.rx.Bytes() -} - -// DecodeRxBufferHeader decodes the header of a buffer posted on an rx queue. -func DecodeRxBufferHeader(b []byte) RxBuffer { - return RxBuffer{ - Offset: binary.LittleEndian.Uint64(b[postedOffset:]), - Size: binary.LittleEndian.Uint32(b[postedSize:]), - ID: binary.LittleEndian.Uint64(b[postedID:]), - UserData: binary.LittleEndian.Uint64(b[postedUserData:]), - } -} - -// RxCompletionSize returns the number of bytes needed to encode an rx -// completion containing "count" buffers. -func RxCompletionSize(count int) uint64 { - return sizeOfConsumedPacketHeader + uint64(count)*sizeOfConsumedBuffer -} - -// EncodeRxCompletion encodes an rx completion header. -func EncodeRxCompletion(b []byte, size, reserved uint32) { - binary.LittleEndian.PutUint32(b[consumedPacketSize:], size) - binary.LittleEndian.PutUint32(b[consumedPacketReserved:], reserved) -} - -// EncodeRxCompletionBuffer encodes the i-th rx completion buffer header. -func EncodeRxCompletionBuffer(b []byte, i int, rxb RxBuffer) { - b = b[RxCompletionSize(i):] - binary.LittleEndian.PutUint64(b[consumedOffset:], rxb.Offset) - binary.LittleEndian.PutUint32(b[consumedSize:], rxb.Size) - binary.LittleEndian.PutUint64(b[consumedUserData:], rxb.UserData) - binary.LittleEndian.PutUint64(b[consumedID:], rxb.ID) -} diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go deleted file mode 100644 index beffe807b..000000000 --- a/pkg/tcpip/link/sharedmem/queue/tx.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "encoding/binary" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" -) - -const ( - // Offsets within a packet header. - packetID = 0 - packetSize = 8 - packetReserved = 12 - - sizeOfPacketHeader = 16 - - // Offsets with a buffer descriptor - bufferOffset = 0 - bufferSize = 8 - - sizeOfBufferDescriptor = 12 -) - -// TxBuffer is the descriptor of a transmit buffer. -type TxBuffer struct { - Next *TxBuffer - Offset uint64 - Size uint32 -} - -// Tx is a transmit queue. It is implemented with one tx and one rx pipe: the -// tx pipe is used to request the transmission of packets, while the rx pipe -// is used to receive which transmissions have completed. -// -// This struct is thread-compatible. -type Tx struct { - tx pipe.Tx - rx pipe.Rx -} - -// Init initializes the transmit queue with the given pipes. -func (t *Tx) Init(tx, rx []byte) { - t.tx.Init(tx) - t.rx.Init(rx) -} - -// Enqueue queues the given linked list of buffers for transmission as one -// packet. While it is queued, the caller must not modify them. -func (t *Tx) Enqueue(id uint64, totalDataLen, bufferCount uint32, buffer *TxBuffer) bool { - // Reserve room in the tx pipe. - totalLen := sizeOfPacketHeader + uint64(bufferCount)*sizeOfBufferDescriptor - - b := t.tx.Push(totalLen) - if b == nil { - return false - } - - // Initialize the packet and buffer descriptors. - binary.LittleEndian.PutUint64(b[packetID:], id) - binary.LittleEndian.PutUint32(b[packetSize:], totalDataLen) - binary.LittleEndian.PutUint32(b[packetReserved:], 0) - - offset := sizeOfPacketHeader - for i := bufferCount; i != 0; i-- { - binary.LittleEndian.PutUint64(b[offset+bufferOffset:], buffer.Offset) - binary.LittleEndian.PutUint32(b[offset+bufferSize:], buffer.Size) - offset += sizeOfBufferDescriptor - buffer = buffer.Next - } - - t.tx.Flush() - - return true -} - -// CompletedPacket returns the id of the last completed transmission. The -// returned id, if any, refers to a value passed on a previous call to -// Enqueue(). -func (t *Tx) CompletedPacket() (id uint64, ok bool) { - for { - b := t.rx.Pull() - if b == nil { - return 0, false - } - - if len(b) != 8 { - t.rx.Flush() - log.Warningf("Ignoring completed packet: size (%v) is less than expected (%v)", len(b), 8) - continue - } - - v := binary.LittleEndian.Uint64(b) - - t.rx.Flush() - - return v, true - } -} - -// Bytes returns the byte slices on which the queue operates. -func (t *Tx) Bytes() (tx, rx []byte) { - return t.tx.Bytes(), t.rx.Bytes() -} - -// TxPacketInfo holds information about a packet sent on a tx queue. -type TxPacketInfo struct { - ID uint64 - Size uint32 - Reserved uint32 - BufferCount int -} - -// DecodeTxPacketHeader decodes the header of a packet sent over a tx queue. -func DecodeTxPacketHeader(b []byte) TxPacketInfo { - return TxPacketInfo{ - ID: binary.LittleEndian.Uint64(b[packetID:]), - Size: binary.LittleEndian.Uint32(b[packetSize:]), - Reserved: binary.LittleEndian.Uint32(b[packetReserved:]), - BufferCount: (len(b) - sizeOfPacketHeader) / sizeOfBufferDescriptor, - } -} - -// DecodeTxBufferHeader decodes the header of the i-th buffer of a packet sent -// over a tx queue. -func DecodeTxBufferHeader(b []byte, i int) TxBuffer { - b = b[sizeOfPacketHeader+i*sizeOfBufferDescriptor:] - return TxBuffer{ - Offset: binary.LittleEndian.Uint64(b[bufferOffset:]), - Size: binary.LittleEndian.Uint32(b[bufferSize:]), - } -} - -// EncodeTxCompletion encodes a tx completion header. -func EncodeTxCompletion(b []byte, id uint64) { - binary.LittleEndian.PutUint64(b, id) -} diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go deleted file mode 100644 index eec11e4cb..000000000 --- a/pkg/tcpip/link/sharedmem/rx.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package sharedmem - -import ( - "sync/atomic" - "syscall" - - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" -) - -// rx holds all state associated with an rx queue. -type rx struct { - data []byte - sharedData []byte - q queue.Rx - eventFD int -} - -// init initializes all state needed by the rx queue based on the information -// provided. -// -// The caller always retains ownership of all file descriptors passed in. The -// queue implementation will duplicate any that it may need in the future. -func (r *rx) init(mtu uint32, c *QueueConfig) error { - // Map in all buffers. - txPipe, err := getBuffer(c.TxPipeFD) - if err != nil { - return err - } - - rxPipe, err := getBuffer(c.RxPipeFD) - if err != nil { - syscall.Munmap(txPipe) - return err - } - - data, err := getBuffer(c.DataFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - return err - } - - sharedData, err := getBuffer(c.SharedDataFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - syscall.Munmap(data) - return err - } - - // Duplicate the eventFD so that caller can close it but we can still - // use it. - efd, err := syscall.Dup(c.EventFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - syscall.Munmap(data) - syscall.Munmap(sharedData) - return err - } - - // Set the eventfd as non-blocking. - if err := syscall.SetNonblock(efd, true); err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - syscall.Munmap(data) - syscall.Munmap(sharedData) - syscall.Close(efd) - return err - } - - // Initialize state based on buffers. - r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData)) - r.data = data - r.eventFD = efd - r.sharedData = sharedData - - return nil -} - -// cleanup releases all resources allocated during init(). It must only be -// called if init() has previously succeeded. -func (r *rx) cleanup() { - a, b := r.q.Bytes() - syscall.Munmap(a) - syscall.Munmap(b) - - syscall.Munmap(r.data) - syscall.Munmap(r.sharedData) - syscall.Close(r.eventFD) -} - -// postAndReceive posts the provided buffers (if any), and then tries to read -// from the receive queue. -// -// Capacity permitting, it reuses the posted buffer slice to store the buffers -// that were read as well. -// -// This function will block if there aren't any available packets. -func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.RxBuffer, uint32) { - // Post the buffers first. If we cannot post, sleep until we can. We - // never post more than will fit concurrently, so it's safe to wait - // until enough room is available. - if len(b) != 0 && !r.q.PostBuffers(b) { - r.q.EnableNotification() - for !r.q.PostBuffers(b) { - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) - if atomic.LoadUint32(stopRequested) != 0 { - r.q.DisableNotification() - return nil, 0 - } - } - r.q.DisableNotification() - } - - // Read the next set of descriptors. - b, n := r.q.Dequeue(b[:0]) - if len(b) != 0 { - return b, n - } - - // Data isn't immediately available. Enable eventfd notifications. - r.q.EnableNotification() - for { - b, n = r.q.Dequeue(b) - if len(b) != 0 { - break - } - - // Wait for notification. - var tmp [8]byte - rawfile.BlockingRead(r.eventFD, tmp[:]) - if atomic.LoadUint32(stopRequested) != 0 { - r.q.DisableNotification() - return nil, 0 - } - } - r.q.DisableNotification() - - return b, n -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go deleted file mode 100644 index 834ea5c40..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// Package sharedmem provides the implemention of data-link layer endpoints -// backed by shared memory. -// -// Shared memory endpoints can be used in the networking stack by calling New() -// to create a new endpoint, and then passing it as an argument to -// Stack.CreateNIC(). -package sharedmem - -import ( - "sync" - "sync/atomic" - "syscall" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// QueueConfig holds all the file descriptors needed to describe a tx or rx -// queue over shared memory. It is used when creating new shared memory -// endpoints to describe tx and rx queues. -type QueueConfig struct { - // DataFD is a file descriptor for the file that contains the data to - // be transmitted via this queue. Descriptors contain offsets within - // this file. - DataFD int - - // EventFD is a file descriptor for the event that is signaled when - // data is becomes available in this queue. - EventFD int - - // TxPipeFD is a file descriptor for the tx pipe associated with the - // queue. - TxPipeFD int - - // RxPipeFD is a file descriptor for the rx pipe associated with the - // queue. - RxPipeFD int - - // SharedDataFD is a file descriptor for the file that contains shared - // state between the two ends of the queue. This data specifies, for - // example, whether EventFD signaling is enabled or disabled. - SharedDataFD int -} - -type endpoint struct { - // mtu (maximum transmission unit) is the maximum size of a packet. - mtu uint32 - - // bufferSize is the size of each individual buffer. - bufferSize uint32 - - // addr is the local address of this endpoint. - addr tcpip.LinkAddress - - // rx is the receive queue. - rx rx - - // stopRequested is to be accessed atomically only, and determines if - // the worker goroutines should stop. - stopRequested uint32 - - // Wait group used to indicate that all workers have stopped. - completed sync.WaitGroup - - // mu protects the following fields. - mu sync.Mutex - - // tx is the transmit queue. - tx tx - - // workerStarted specifies whether the worker goroutine was started. - workerStarted bool -} - -// New creates a new shared-memory-based endpoint. Buffers will be broken up -// into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { - e := &endpoint{ - mtu: mtu, - bufferSize: bufferSize, - addr: addr, - } - - if err := e.tx.init(bufferSize, &tx); err != nil { - return 0, err - } - - if err := e.rx.init(bufferSize, &rx); err != nil { - e.tx.cleanup() - return 0, err - } - - return stack.RegisterLinkEndpoint(e), nil -} - -// Close frees all resources associated with the endpoint. -func (e *endpoint) Close() { - // Tell dispatch goroutine to stop, then write to the eventfd so that - // it wakes up in case it's sleeping. - atomic.StoreUint32(&e.stopRequested, 1) - syscall.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Cleanup the queues inline if the worker hasn't started yet; we also - // know it won't start from now on because stopRequested is set to 1. - e.mu.Lock() - workerPresent := e.workerStarted - e.mu.Unlock() - - if !workerPresent { - e.tx.cleanup() - e.rx.cleanup() - } -} - -// Wait waits until all workers have stopped after a Close() call. -func (e *endpoint) Wait() { - e.completed.Wait() -} - -// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that -// reads packets from the rx queue. -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.mu.Lock() - if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { - e.workerStarted = true - e.completed.Add(1) - // Link endpoints are not savable. When transportation endpoints - // are saved, they stop sending outgoing packets and all - // incoming packets are rejected. - go e.dispatchLoop(dispatcher) // S/R-SAFE: see above. - } - e.mu.Unlock() -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *endpoint) IsAttached() bool { - e.mu.Lock() - defer e.mu.Unlock() - return e.workerStarted -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *endpoint) MTU() uint32 { - return e.mtu - header.EthernetMinimumSize -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the -// ethernet frame header size. -func (*endpoint) MaxHeaderLength() uint16 { - return header.EthernetMinimumSize -} - -// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local -// link address. -func (e *endpoint) LinkAddress() tcpip.LinkAddress { - return e.addr -} - -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - // Add the ethernet header here. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, - } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - - v := payload.ToView() - // Transmit the packet. - e.mu.Lock() - ok := e.tx.transmit(hdr.View(), v) - e.mu.Unlock() - - if !ok { - return tcpip.ErrWouldBlock - } - - return nil -} - -// dispatchLoop reads packets from the rx queue in a loop and dispatches them -// to the network stack. -func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { - // Post initial set of buffers. - limit := e.rx.q.PostedBuffersLimit() - if l := uint64(len(e.rx.data)) / uint64(e.bufferSize); limit > l { - limit = l - } - for i := uint64(0); i < limit; i++ { - b := queue.RxBuffer{ - Offset: i * uint64(e.bufferSize), - Size: e.bufferSize, - ID: i, - } - if !e.rx.q.PostBuffers([]queue.RxBuffer{b}) { - log.Warningf("Unable to post %v-th buffer", i) - } - } - - // Read in a loop until a stop is requested. - var rxb []queue.RxBuffer - for atomic.LoadUint32(&e.stopRequested) == 0 { - var n uint32 - rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested) - - // Copy data from the shared area to its own buffer, then - // prepare to repost the buffer. - b := make([]byte, n) - offset := uint32(0) - for i := range rxb { - copy(b[offset:], e.rx.data[rxb[i].Offset:][:rxb[i].Size]) - offset += rxb[i].Size - - rxb[i].Size = e.bufferSize - } - - if n < header.EthernetMinimumSize { - continue - } - - // Send packet up the stack. - eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView()) - } - - // Clean state. - e.tx.cleanup() - e.rx.cleanup() - - e.completed.Done() -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go deleted file mode 100644 index 75f08dfb0..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ /dev/null @@ -1,776 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package sharedmem - -import ( - "bytes" - "io/ioutil" - "math/rand" - "os" - "strings" - "sync" - "syscall" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe" - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - localLinkAddr = "\xde\xad\xbe\xef\x56\x78" - remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" - - queueDataSize = 1024 * 1024 - queuePipeSize = 4096 -) - -type queueBuffers struct { - data []byte - rx pipe.Tx - tx pipe.Rx -} - -func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) { - // Prepare tx pipe. - b, err := getBuffer(c.TxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.tx.Init(b) - - // Prepare rx pipe. - b, err = getBuffer(c.RxPipeFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } - q.rx.Init(b) - - // Get data slice. - q.data, err = getBuffer(c.DataFD) - if err != nil { - t.Fatalf("getBuffer failed: %v", err) - } -} - -func (q *queueBuffers) cleanup() { - syscall.Munmap(q.tx.Bytes()) - syscall.Munmap(q.rx.Bytes()) - syscall.Munmap(q.data) -} - -type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView -} - -type testContext struct { - t *testing.T - ep *endpoint - txCfg QueueConfig - rxCfg QueueConfig - txq queueBuffers - rxq queueBuffers - - packetCh chan struct{} - mu sync.Mutex - packets []packetInfo -} - -func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext { - var err error - c := &testContext{ - t: t, - packetCh: make(chan struct{}, 1000000), - } - c.txCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - c.rxCfg = createQueueFDs(t, queueSizes{ - dataSize: queueDataSize, - txPipeSize: queuePipeSize, - rxPipeSize: queuePipeSize, - sharedDataSize: 4096, - }) - - initQueue(t, &c.txq, &c.txCfg) - initQueue(t, &c.rxq, &c.rxCfg) - - id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - c.ep = stack.FindLinkEndpoint(id).(*endpoint) - c.ep.Attach(c) - - return c -} - -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - c.mu.Lock() - c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - vv: vv.Clone(nil), - }) - c.mu.Unlock() - - c.packetCh <- struct{}{} -} - -func (c *testContext) cleanup() { - c.ep.Close() - closeFDs(&c.txCfg) - closeFDs(&c.rxCfg) - c.txq.cleanup() - c.rxq.cleanup() -} - -func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) { - for i := 0; i < n; i++ { - select { - case <-c.packetCh: - case <-to: - c.t.Fatalf(errorStr) - } - } -} - -func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) { - b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs))) - queue.EncodeRxCompletion(b, size, 0) - for i := range bs { - queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{ - Offset: bs[i].Offset, - Size: bs[i].Size, - ID: bs[i].ID, - }) - } -} - -func randomFill(b []byte) { - for i := range b { - b[i] = byte(rand.Intn(256)) - } -} - -func shuffle(b []int) { - for i := len(b) - 1; i >= 0; i-- { - j := rand.Intn(i + 1) - b[i], b[j] = b[j], b[i] - } -} - -func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir := os.Getenv("TEST_TMPDIR") - if tmpDir == "" { - tmpDir = os.Getenv("TMPDIR") - } - f, err := ioutil.TempFile(tmpDir, "sharedmem_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - defer f.Close() - syscall.Unlink(f.Name()) - - if initQueue { - // Write the "slot-free" flag in the initial queue. - _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) - if err != nil { - t.Fatalf("WriteAt failed: %v", err) - } - } - - fd, err := syscall.Dup(int(f.Fd())) - if err != nil { - t.Fatalf("Dup failed: %v", err) - } - - if err := syscall.Ftruncate(fd, size); err != nil { - syscall.Close(fd) - t.Fatalf("Ftruncate failed: %v", err) - } - - return fd -} - -func closeFDs(c *QueueConfig) { - syscall.Close(c.DataFD) - syscall.Close(c.EventFD) - syscall.Close(c.TxPipeFD) - syscall.Close(c.RxPipeFD) - syscall.Close(c.SharedDataFD) -} - -type queueSizes struct { - dataSize int64 - txPipeSize int64 - rxPipeSize int64 - sharedDataSize int64 -} - -func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { - fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0) - if err != 0 { - t.Fatalf("eventfd failed: %v", error(err)) - } - - return QueueConfig{ - EventFD: int(fd), - DataFD: createFile(t, s.dataSize, false), - TxPipeFD: createFile(t, s.txPipeSize, true), - RxPipeFD: createFile(t, s.rxPipeSize, true), - SharedDataFD: createFile(t, s.sharedDataSize, false), - } -} - -// TestSimpleSend sends 1000 packets with random header and payload sizes, -// then checks that the right payload is received on the shared memory queues. -func TestSimpleSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - for iters := 1000; iters > 0; iters-- { - func() { - // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) - randomFill(hdrBuf) - - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check the ethernet header. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: localLinkAddr, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } - - // Compare contents skipping the ethernet header added by the - // endpoint. - merged := append(hdrBuf, buf...) - if uint32(len(contents)) < pi.Size { - t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) - } - contents = contents[:pi.Size][header.EthernetMinimumSize:] - - if !bytes.Equal(contents, merged) { - t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) - } - }() - } -} - -// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress -// set in Route (using much of the same code as TestSimpleSend), then checks -// that the encoded ethernet header received includes the correct SrcAddr. -func TestPreserveSrcAddressInSend(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) - // Set both remote and local link address in route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - LocalLinkAddress: newLocalLinkAddress, - } - - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } - - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() - - defer func() { - // Tell the endpoint about the completion of the write. - b := c.txq.rx.Push(8) - queue.EncodeTxCompletion(b, pi.ID) - c.txq.rx.Flush() - }() - - // Check that the ethernet header contains the expected SrcAddr. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: newLocalLinkAddress, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } -} - -// TestFillTxQueue sends packets until the queue is full. -func TestFillTxQueue(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets -// until the queue is full. -func TestFillTxQueueAfterBadCompletion(t *testing.T) { - c := newTestContext(t, 20000, 1500, localLinkAddr) - defer c.cleanup() - - // Send a bad completion. - queue.EncodeTxCompletion(c.txq.rx.Push(8), 1) - c.txq.rx.Flush() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Send two packets so that the id slice has at least two slots. - for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } - - // Complete the two writes twice. - for i := 2; i > 0; i-- { - pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull()) - - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) - c.txq.rx.Flush() - } - c.txq.tx.Flush() - - // Each packet is uses no more than 40 bytes, so write that many packets - // until the tx queue if full. - ids := make(map[uint64]struct{}) - for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxMemory sends packets until the we run out of shared memory. -func TestFillTxMemory(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible until - // we fill the memory. - ids := make(map[uint64]struct{}) - for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Check that they have different IDs. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - if _, ok := ids[pi.ID]; ok { - t.Fatalf("ID (%v) reused", pi.ID) - } - ids[pi.ID] = struct{}{} - c.txq.tx.Flush() - } - - // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) - if want := tcpip.ErrWouldBlock; err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } -} - -// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of -// shared memory for a 2-buffer packet, but still with room for a 1-buffer -// packet. -func TestFillTxMemoryWithMultiBuffer(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } - - buf := buffer.NewView(100) - - // Each packet is uses up one buffer, so write as many as possible - // until there is only one buffer left. - for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - - // Pull the posted buffer. - c.txq.tx.Pull() - c.txq.tx.Flush() - } - - // Attempt to write a two-buffer packet. It must fail. - { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) - } - } - - // Attempt to write the one-buffer packet again. It must succeed. - { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) - } - } -} - -func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte { - t.Helper() - - for { - b := p.Pull() - if b != nil { - return b - } - - select { - case <-time.After(10 * time.Millisecond): - case <-to: - t.Fatal(errStr) - } - } -} - -// TestSimpleReceive completes 1000 different receives with random payload and -// random number of buffers. It checks that the contents match the expected -// values. -func TestSimpleReceive(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Check that buffers have been posted. - limit := c.ep.rx.q.PostedBuffersLimit() - for i := uint64(0); i < limit; i++ { - timeout := time.After(2 * time.Second) - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted")) - - if want := i * bufferSize; want != bi.Offset { - t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want) - } - - if want := i; want != bi.ID { - t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want) - } - - if bufferSize != bi.Size { - t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize) - } - } - c.rxq.tx.Flush() - - // Create a slice with the indices 0..limit-1. - idx := make([]int, limit) - for i := range idx { - idx[i] = i - } - - // Complete random packets 1000 times. - for iters := 1000; iters > 0; iters-- { - timeout := time.After(2 * time.Second) - // Prepare a random packet. - shuffle(idx) - n := 1 + rand.Intn(10) - bufs := make([]queue.RxBuffer, n) - contents := make([]byte, bufferSize*n-rand.Intn(500)) - randomFill(contents) - for i := range bufs { - j := idx[i] - bufs[i].Size = bufferSize - bufs[i].Offset = uint64(bufferSize * j) - bufs[i].ID = uint64(j) - - copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:]) - } - - // Push completion. - c.pushRxCompletion(uint32(len(contents)), bufs) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be received, then check it. - c.waitForPackets(1, time.After(time.Second), "Error waiting for packet") - c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) - c.packets = c.packets[:0] - c.mu.Unlock() - - if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) { - t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents) - } - - // Check that buffers have been reposted. - for i := range bufs { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted")) - if bi != bufs[i] { - t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i]) - } - } - c.rxq.tx.Flush() - } -} - -// TestRxBuffersReposted tests that rx buffers get reposted after they have been -// completed. -func TestRxBuffersReposted(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Receive all posted buffers. - limit := c.ep.rx.q.PostedBuffersLimit() - buffers := make([]queue.RxBuffer, 0, limit) - for i := limit; i > 0; i-- { - timeout := time.After(2 * time.Second) - buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers"))) - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when individually completed. - for i := range buffers { - timeout := time.After(2 * time.Second) - // Complete the buffer. - c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for it to be reposted. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[i] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i]) - } - } - c.rxq.tx.Flush() - - // Check that all buffers are reposted when completed in pairs. - for i := 0; i < len(buffers)/2; i++ { - timeout := time.After(2 * time.Second) - // Complete with two buffers. - c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for them to be reposted. - for j := 0; j < 2; j++ { - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) - if bi != buffers[2*i+j] { - t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j]) - } - } - } - c.rxq.tx.Flush() -} - -// TestReceivePostingIsFull checks that the endpoint will properly handle the -// case when a received buffer cannot be immediately reposted because it hasn't -// been pulled from the tx pipe yet. -func TestReceivePostingIsFull(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - defer c.cleanup() - - // Complete first posted buffer before flushing it from the tx pipe. - first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) - c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that packet is received. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Complete another buffer. - second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) - c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that no packet is received yet, as the worker is blocked trying - // to repost. - select { - case <-time.After(500 * time.Millisecond): - case <-c.packetCh: - t.Fatalf("Unexpected packet received") - } - - // Flush tx queue, which will allow the first buffer to be reposted, - // and the second completion to be pulled. - c.rxq.tx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Check that second packet completes. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") -} - -// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to -// repost a buffer. Make sure it backs out. -func TestCloseWhileWaitingToPost(t *testing.T) { - const bufferSize = 1500 - c := newTestContext(t, 20000, bufferSize, localLinkAddr) - cleaned := false - defer func() { - if !cleaned { - c.cleanup() - } - }() - - // Complete first posted buffer before flushing it from the tx pipe. - bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) - c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) - c.rxq.rx.Flush() - syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) - - // Wait for packet to be indicated. - c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") - - // Cleanup and wait for worker to complete. - c.cleanup() - cleaned = true - c.ep.Wait() -} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go deleted file mode 100644 index f7e816a41..000000000 --- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sharedmem - -import ( - "unsafe" -) - -// sharedDataPointer converts the shared data slice into a pointer so that it -// can be used in atomic operations. -func sharedDataPointer(sharedData []byte) *uint32 { - return (*uint32)(unsafe.Pointer(&sharedData[0:4][0])) -} diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go deleted file mode 100644 index 6b8d7859d..000000000 --- a/pkg/tcpip/link/sharedmem/tx.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sharedmem - -import ( - "math" - "syscall" - - "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" -) - -const ( - nilID = math.MaxUint64 -) - -// tx holds all state associated with a tx queue. -type tx struct { - data []byte - q queue.Tx - ids idManager - bufs bufferManager -} - -// init initializes all state needed by the tx queue based on the information -// provided. -// -// The caller always retains ownership of all file descriptors passed in. The -// queue implementation will duplicate any that it may need in the future. -func (t *tx) init(mtu uint32, c *QueueConfig) error { - // Map in all buffers. - txPipe, err := getBuffer(c.TxPipeFD) - if err != nil { - return err - } - - rxPipe, err := getBuffer(c.RxPipeFD) - if err != nil { - syscall.Munmap(txPipe) - return err - } - - data, err := getBuffer(c.DataFD) - if err != nil { - syscall.Munmap(txPipe) - syscall.Munmap(rxPipe) - return err - } - - // Initialize state based on buffers. - t.q.Init(txPipe, rxPipe) - t.ids.init() - t.bufs.init(0, len(data), int(mtu)) - t.data = data - - return nil -} - -// cleanup releases all resources allocated during init(). It must only be -// called if init() has previously succeeded. -func (t *tx) cleanup() { - a, b := t.q.Bytes() - syscall.Munmap(a) - syscall.Munmap(b) - syscall.Munmap(t.data) -} - -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { - // Pull completions from the tx queue and add their buffers back to the - // pool so that we can reuse them. - for { - id, ok := t.q.CompletedPacket() - if !ok { - break - } - - if buf := t.ids.remove(id); buf != nil { - t.bufs.free(buf) - } - } - - bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) - bufCount := (total + bSize - 1) / bSize - - // Allocate enough buffers to hold all the data. - var buf *queue.TxBuffer - for i := bufCount; i != 0; i-- { - b := t.bufs.alloc() - if b == nil { - // Failed to get all buffers. Return to the pool - // whatever we had managed to get. - if buf != nil { - t.bufs.free(buf) - } - return false - } - b.Next = buf - buf = b - } - - // Copy data into allocated buffers. - nBuf := buf - var dBuf []byte - for _, data := range [][]byte{a, b} { - for len(data) > 0 { - if len(dBuf) == 0 { - dBuf = t.data[nBuf.Offset:][:nBuf.Size] - nBuf = nBuf.Next - } - n := copy(dBuf, data) - data = data[n:] - dBuf = dBuf[n:] - } - } - - // Get an id for this packet and send it out. - id := t.ids.add(buf) - if !t.q.Enqueue(id, total, bufCount, buf) { - t.ids.remove(id) - t.bufs.free(buf) - return false - } - - return true -} - -// getBuffer returns a memory region mapped to the full contents of the given -// file descriptor. -func getBuffer(fd int) ([]byte, error) { - var s syscall.Stat_t - if err := syscall.Fstat(fd, &s); err != nil { - return nil, err - } - - // Check that size doesn't overflow an int. - if s.Size > int64(^uint(0)>>1) { - return nil, syscall.EDOM - } - - return syscall.Mmap(fd, 0, int(s.Size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED|syscall.MAP_FILE) -} - -// idDescriptor is used by idManager to either point to a tx buffer (in case -// the ID is assigned) or to the next free element (if the id is not assigned). -type idDescriptor struct { - buf *queue.TxBuffer - nextFree uint64 -} - -// idManager is a manager of tx buffer identifiers. It assigns unique IDs to -// tx buffers that are added to it; the IDs can only be reused after they have -// been removed. -// -// The ID assignments are stored so that the tx buffers can be retrieved from -// the IDs previously assigned to them. -type idManager struct { - // ids is a slice containing all tx buffers. The ID is the index into - // this slice. - ids []idDescriptor - - // freeList a list of free IDs. - freeList uint64 -} - -// init initializes the id manager. -func (m *idManager) init() { - m.freeList = nilID -} - -// add assigns an ID to the given tx buffer. -func (m *idManager) add(b *queue.TxBuffer) uint64 { - if i := m.freeList; i != nilID { - // There is an id available in the free list, just use it. - m.ids[i].buf = b - m.freeList = m.ids[i].nextFree - return i - } - - // We need to expand the id descriptor. - m.ids = append(m.ids, idDescriptor{buf: b}) - return uint64(len(m.ids) - 1) -} - -// remove retrieves the tx buffer associated with the given ID, and removes the -// ID from the assigned table so that it can be reused in the future. -func (m *idManager) remove(i uint64) *queue.TxBuffer { - if i >= uint64(len(m.ids)) { - return nil - } - - desc := &m.ids[i] - b := desc.buf - if b == nil { - // The provided id is not currently assigned. - return nil - } - - desc.buf = nil - desc.nextFree = m.freeList - m.freeList = i - - return b -} - -// bufferManager manages a buffer region broken up into smaller, equally sized -// buffers. Smaller buffers can be allocated and freed. -type bufferManager struct { - freeList *queue.TxBuffer - curOffset uint64 - limit uint64 - entrySize uint32 -} - -// init initializes the buffer manager. -func (b *bufferManager) init(initialOffset, size, entrySize int) { - b.freeList = nil - b.curOffset = uint64(initialOffset) - b.limit = uint64(initialOffset + size/entrySize*entrySize) - b.entrySize = uint32(entrySize) -} - -// alloc allocates a buffer from the manager, if one is available. -func (b *bufferManager) alloc() *queue.TxBuffer { - if b.freeList != nil { - // There is a descriptor ready for reuse in the free list. - d := b.freeList - b.freeList = d.Next - d.Next = nil - return d - } - - if b.curOffset < b.limit { - // There is room available in the never-used range, so create - // a new descriptor for it. - d := &queue.TxBuffer{ - Offset: b.curOffset, - Size: b.entrySize, - } - b.curOffset += uint64(b.entrySize) - return d - } - - return nil -} - -// free returns all buffers in the list to the buffer manager so that they can -// be reused. -func (b *bufferManager) free(d *queue.TxBuffer) { - // Find the last buffer in the list. - last := d - for last.Next != nil { - last = last.Next - } - - // Push list onto free list. - last.Next = b.freeList - b.freeList = d -} diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD deleted file mode 100644 index 1756114e6..000000000 --- a/pkg/tcpip/link/sniffer/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "sniffer", - srcs = [ - "pcap.go", - "sniffer.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go new file mode 100755 index 000000000..cfd84a739 --- /dev/null +++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package sniffer + diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD deleted file mode 100644 index 92dce8fac..000000000 --- a/pkg/tcpip/link/tun/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "tun", - srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", - visibility = [ - "//visibility:public", - ], -) diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go deleted file mode 100644 index 09ca9b527..000000000 --- a/pkg/tcpip/link/tun/tun_unsafe.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// Package tun contains methods to open TAP and TUN devices. -package tun - -import ( - "syscall" - "unsafe" -) - -// Open opens the specified TUN device, sets it to non-blocking mode, and -// returns its file descriptor. -func Open(name string) (int, error) { - return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI) -} - -// OpenTAP opens the specified TAP device, sets it to non-blocking mode, and -// returns its file descriptor. -func OpenTAP(name string) (int, error) { - return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI) -} - -func open(name string, flags uint16) (int, error) { - fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0) - if err != nil { - return -1, err - } - - var ifr struct { - name [16]byte - flags uint16 - _ [22]byte - } - - copy(ifr.name[:], name) - ifr.flags = flags - _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) - if errno != 0 { - syscall.Close(fd) - return -1, errno - } - - if err = syscall.SetNonblock(fd, true); err != nil { - syscall.Close(fd) - return -1, err - } - - return fd, nil -} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD deleted file mode 100644 index 2597d4b3e..000000000 --- a/pkg/tcpip/link/waitable/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "waitable", - srcs = [ - "waitable.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/gate", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "waitable_test", - srcs = [ - "waitable_test.go", - ], - embed = [":waitable"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - ], -) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go deleted file mode 100644 index 3b6ac2ff7..000000000 --- a/pkg/tcpip/link/waitable/waitable.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package waitable provides the implementation of data-link layer endpoints -// that wrap other endpoints, and can wait for inflight calls to WritePacket or -// DeliverNetworkPacket to finish (and new ones to be prevented). -// -// Waitable endpoints can be used in the networking stack by calling New(eID) to -// create a new endpoint, where eID is the ID of the endpoint being wrapped, -// and then passing it as an argument to Stack.CreateNIC(). -package waitable - -import ( - "gvisor.dev/gvisor/pkg/gate" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// Endpoint is a waitable link-layer endpoint. -type Endpoint struct { - dispatchGate gate.Gate - dispatcher stack.NetworkDispatcher - - writeGate gate.Gate - lower stack.LinkEndpoint -} - -// New creates a new waitable link-layer endpoint. It wraps around another -// endpoint and allows the caller to block new write/dispatch calls and wait for -// the inflight ones to finish before returning. -func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - lower: stack.FindLinkEndpoint(lower), - } - return stack.RegisterLinkEndpoint(e), e -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -// It is called by the link-layer endpoint being wrapped when a packet arrives, -// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't -// been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - if !e.dispatchGate.Enter() { - return - } - - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) - e.dispatchGate.Leave() -} - -// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and -// registers with the lower endpoint as its dispatcher so that "e" is called -// for inbound packets. -func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - e.lower.Attach(e) -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *Endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the -// lower endpoint. -func (e *Endpoint) MTU() uint32 { - return e.lower.MTU() -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the -// request to the lower endpoint. -func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.lower.Capabilities() -} - -// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just -// forwards the request to the lower endpoint. -func (e *Endpoint) MaxHeaderLength() uint16 { - return e.lower.MaxHeaderLength() -} - -// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the -// request to the lower endpoint. -func (e *Endpoint) LinkAddress() tcpip.LinkAddress { - return e.lower.LinkAddress() -} - -// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by -// higher-level protocols to write packets. It only forwards packets to the -// lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - if !e.writeGate.Enter() { - return nil - } - - err := e.lower.WritePacket(r, gso, hdr, payload, protocol) - e.writeGate.Leave() - return err -} - -// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, -// and waits for inflight ones to finish before returning. -func (e *Endpoint) WaitWrite() { - e.writeGate.Close() -} - -// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the -// actual dispatcher, and waits for inflight ones to finish before returning. -func (e *Endpoint) WaitDispatch() { - e.dispatchGate.Close() -} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go deleted file mode 100644 index 56e18ecb0..000000000 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waitable - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type countedEndpoint struct { - dispatchCount int - writeCount int - attachCount int - - mtu uint32 - capabilities stack.LinkEndpointCapabilities - hdrLen uint16 - linkAddr tcpip.LinkAddress - - dispatcher stack.NetworkDispatcher -} - -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.dispatchCount++ -} - -func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.attachCount++ - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *countedEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *countedEndpoint) MTU() uint32 { - return e.mtu -} - -func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities -} - -func (e *countedEndpoint) MaxHeaderLength() uint16 { - return e.hdrLen -} - -func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - e.writeCount++ - return nil -} - -func TestWaitWrite(t *testing.T) { - ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) - - // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) - if want := 1; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on dispatches, then try to write. It must go through. - wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } - - // Wait on writes, then try to write. It must not go through. - wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) - if want := 2; ep.writeCount != want { - t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) - } -} - -func TestWaitDispatch(t *testing.T) { - ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) - - // Check that attach happens. - wep.Attach(ep) - if want := 1; ep.attachCount != want { - t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want) - } - - // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) - if want := 1; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on writes, then try to dispatch. It must go through. - wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } - - // Wait on dispatches, then try to dispatch. It must not go through. - wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) - if want := 2; ep.dispatchCount != want { - t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) - } -} - -func TestOtherMethods(t *testing.T) { - const ( - mtu = 0xdead - capabilities = 0xbeef - hdrLen = 0x1234 - linkAddr = "test address" - ) - ep := &countedEndpoint{ - mtu: mtu, - capabilities: capabilities, - hdrLen: hdrLen, - linkAddr: linkAddr, - } - _, wep := New(stack.RegisterLinkEndpoint(ep)) - - if v := wep.MTU(); v != mtu { - t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) - } - - if v := wep.Capabilities(); v != capabilities { - t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities) - } - - if v := wep.MaxHeaderLength(); v != hdrLen { - t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen) - } - - if v := wep.LinkAddress(); v != linkAddr { - t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr) - } -} diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD deleted file mode 100644 index f36f49453..000000000 --- a/pkg/tcpip/network/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "ip_test", - size = "small", - srcs = [ - "ip_test.go", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - ], -) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD deleted file mode 100644 index d95d44f56..000000000 --- a/pkg/tcpip/network/arp/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "arp", - srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "arp_test", - size = "small", - srcs = ["arp_test.go"], - deps = [ - ":arp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - ], -) diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go new file mode 100755 index 000000000..14a21baff --- /dev/null +++ b/pkg/tcpip/network/arp/arp_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package arp + diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go deleted file mode 100644 index 66c55821b..000000000 --- a/pkg/tcpip/network/arp/arp_test.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arp_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" -) - -const ( - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") -) - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack -} - -func newTestContext(t *testing.T) *testContext { - s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) - - const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) - if testing.Verbose() { - id = sniffer.New(id) - } - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("AddAddress for arp failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }}) - - return &testContext{ - t: t, - s: s, - linkEP: linkEP, - } -} - -func (c *testContext) cleanup() { - close(c.linkEP.C) -} - -func TestDirectRequest(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - - inject := func(addr tcpip.Address) { - copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) - } - - inject(stackAddr1) - { - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) - } - rep := header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { - t.Errorf("stackAddr1: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) - } - } - - inject(stackAddr2) - { - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) - } - rep := header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { - t.Errorf("stackAddr2: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) - } - } - - inject(stackAddrBad) - select { - case pkt := <-c.linkEP.C: - t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. - } -} diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD deleted file mode 100644 index 118bfc763..000000000 --- a/pkg/tcpip/network/fragmentation/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "reassembler_list", - out = "reassembler_list.go", - package = "fragmentation", - prefix = "reassembler", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*reassembler", - "Linker": "*reassembler", - }, -) - -go_library( - name = "fragmentation", - srcs = [ - "frag_heap.go", - "fragmentation.go", - "reassembler.go", - "reassembler_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//:sandbox"], - deps = [ - "//pkg/log", - "//pkg/tcpip/buffer", - ], -) - -go_test( - name = "fragmentation_test", - size = "small", - srcs = [ - "frag_heap_test.go", - "fragmentation_test.go", - "reassembler_test.go", - ], - embed = [":fragmentation"], - deps = ["//pkg/tcpip/buffer"], -) - -filegroup( - name = "autogen", - srcs = [ - "reassembler_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go new file mode 100755 index 000000000..6b71d27c7 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go @@ -0,0 +1,38 @@ +// automatically generated by stateify. + +package fragmentation + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *reassemblerList) beforeSave() {} +func (x *reassemblerList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *reassemblerList) afterLoad() {} +func (x *reassemblerList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *reassemblerEntry) beforeSave() {} +func (x *reassemblerEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *reassemblerEntry) afterLoad() {} +func (x *reassemblerEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("fragmentation.reassemblerList", (*reassemblerList)(nil), state.Fns{Save: (*reassemblerList).save, Load: (*reassemblerList).load}) + state.Register("fragmentation.reassemblerEntry", (*reassemblerEntry)(nil), state.Fns{Save: (*reassemblerEntry).save, Load: (*reassemblerEntry).load}) +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go deleted file mode 100644 index 799798544..000000000 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -// vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) buffer.VectorisedView { - views := make([]buffer.View, len(pieces)) - for i, p := range pieces { - views[i] = []byte(p) - } - - return buffer.NewVectorisedView(size, views) -} - -type processInput struct { - id uint32 - first uint16 - last uint16 - more bool - vv buffer.VectorisedView -} - -type processOutput struct { - vv buffer.VectorisedView - done bool -} - -var processTestCases = []struct { - comment string - in []processInput - out []processOutput -}{ - { - comment: "One ID", - in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, - { - comment: "Two IDs", - in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, - }, - out: []processOutput{ - {vv: buffer.VectorisedView{}, done: false}, - {vv: buffer.VectorisedView{}, done: false}, - {vv: vv(4, "ab", "cd"), done: true}, - {vv: vv(4, "01", "23"), done: true}, - }, - }, -} - -func TestFragmentationProcess(t *testing.T) { - for _, c := range processTestCases { - t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) - for i, in := range c.in { - vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) - if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) - } - if done != c.out[i].done { - t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) - } - if c.out[i].done { - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Process(%d) did not remove buffer from reassemblers", i) - } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Process(%d) did not remove buffer from rList", i) - } - } - } - } - }) - } -} - -func TestReassemblingTimeout(t *testing.T) { - timeout := time.Millisecond - f := NewFragmentation(1024, 512, timeout) - // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(0, 0, 0, true, vv(1, "0")) - // Sleep more than the timeout. - time.Sleep(2 * timeout) - // Send another fragment that completes a packet. - // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done := f.Process(0, 1, 1, false, vv(1, "1")) - if done { - t.Errorf("Fragmentation does not respect the reassembling timeout.") - } -} - -func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(3, 1, DefaultReassembleTimeout) - // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) - // Send first fragment with id = 1. - f.Process(1, 0, 0, true, vv(1, "1")) - // Send first fragment with id = 2. - f.Process(2, 0, 0, true, vv(1, "2")) - - // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be - // evicted. - f.Process(3, 0, 0, true, vv(1, "3")) - - if _, ok := f.reassemblers[0]; ok { - t.Errorf("Memory limits are not respected: id=0 has not been evicted.") - } - if _, ok := f.reassemblers[1]; ok { - t.Errorf("Memory limits are not respected: id=1 has not been evicted.") - } - if _, ok := f.reassemblers[3]; !ok { - t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") - } -} - -func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(1, 0, DefaultReassembleTimeout) - // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) - // Send the same packet again. - f.Process(0, 0, 0, true, vv(1, "0")) - - got := f.size - want := 1 - if got != want { - t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler_list.go b/pkg/tcpip/network/fragmentation/reassembler_list.go new file mode 100755 index 000000000..3189cae29 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/reassembler_list.go @@ -0,0 +1,173 @@ +package fragmentation + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type reassemblerElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type reassemblerList struct { + head *reassembler + tail *reassembler +} + +// Reset resets list l to the empty state. +func (l *reassemblerList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *reassemblerList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *reassemblerList) Front() *reassembler { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *reassemblerList) Back() *reassembler { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *reassemblerList) PushFront(e *reassembler) { + reassemblerElementMapper{}.linkerFor(e).SetNext(l.head) + reassemblerElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *reassemblerList) PushBack(e *reassembler) { + reassemblerElementMapper{}.linkerFor(e).SetNext(nil) + reassemblerElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *reassemblerList) PushBackList(m *reassemblerList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head) + reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *reassemblerList) InsertAfter(b, e *reassembler) { + a := reassemblerElementMapper{}.linkerFor(b).Next() + reassemblerElementMapper{}.linkerFor(e).SetNext(a) + reassemblerElementMapper{}.linkerFor(e).SetPrev(b) + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *reassemblerList) InsertBefore(a, e *reassembler) { + b := reassemblerElementMapper{}.linkerFor(a).Prev() + reassemblerElementMapper{}.linkerFor(e).SetNext(a) + reassemblerElementMapper{}.linkerFor(e).SetPrev(b) + reassemblerElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + reassemblerElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *reassemblerList) Remove(e *reassembler) { + prev := reassemblerElementMapper{}.linkerFor(e).Prev() + next := reassemblerElementMapper{}.linkerFor(e).Next() + + if prev != nil { + reassemblerElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + reassemblerElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type reassemblerEntry struct { + next *reassembler + prev *reassembler +} + +// Next returns the entry that follows e in the list. +func (e *reassemblerEntry) Next() *reassembler { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *reassemblerEntry) Prev() *reassembler { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *reassemblerEntry) SetNext(elem *reassembler) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go deleted file mode 100644 index 7eee0710d..000000000 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "math" - "reflect" - "testing" -) - -type updateHolesInput struct { - first uint16 - last uint16 - more bool -} - -var holesTestCases = []struct { - comment string - in []updateHolesInput - want []hole -}{ - { - comment: "No fragments. Expected holes: {[0 -> inf]}.", - in: []updateHolesInput{}, - want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, - }, - { - comment: "One fragment at beginning. Expected holes: {[2, inf]}.", - in: []updateHolesInput{{first: 0, last: 1, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: false}, - }, - }, - { - comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", - in: []updateHolesInput{{first: 1, last: 2, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - {first: 3, last: math.MaxUint16, deleted: false}, - }, - }, - { - comment: "One fragment at the end. Expected holes: {[0, 0]}.", - in: []updateHolesInput{{first: 1, last: 2, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - }, - }, - { - comment: "One fragment completing a packet. Expected holes: {}.", - in: []updateHolesInput{{first: 0, last: 1, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - }, - }, - { - comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 1, more: true}, - {first: 2, last: 3, more: false}, - }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: true}, - }, - }, - { - comment: "Two overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 2, more: true}, - {first: 2, last: 3, more: false}, - }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 3, last: math.MaxUint16, deleted: true}, - }, - }, -} - -func TestUpdateHoles(t *testing.T) { - for _, c := range holesTestCases { - r := newReassembler(0) - for _, i := range c.in { - r.updateHoles(i.first, i.last, i.more) - } - if !reflect.DeepEqual(r.holes, c.want) { - t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) - } - } -} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD deleted file mode 100644 index e6db5c0b0..000000000 --- a/pkg/tcpip/network/hash/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hash", - srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", - visibility = ["//visibility:public"], - deps = [ - "//pkg/rand", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go new file mode 100755 index 000000000..a3bcd4b69 --- /dev/null +++ b/pkg/tcpip/network/hash/hash_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package hash + diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go deleted file mode 100644 index db65ee7cc..000000000 --- a/pkg/tcpip/network/ip_test.go +++ /dev/null @@ -1,607 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ip_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - localIpv4Addr = "\x0a\x00\x00\x01" - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" -) - -// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. -// The former is used to pretend that it's a link endpoint so that we can -// inspect packets written by the network endpoints. The latter is used to -// pretend that it's the network stack so that it can inspect incoming packets -// that have been handled by the network endpoints. -// -// Packets are checked by comparing their fields/values against the expected -// values stored in the test object itself. -type testObject struct { - t *testing.T - protocol tcpip.TransportProtocolNumber - contents []byte - srcAddr tcpip.Address - dstAddr tcpip.Address - v4 bool - typ stack.ControlType - extra uint32 - - dataCalls int - controlCalls int -} - -// checkValues verifies that the transport protocol, data contents, src & dst -// addresses of a packet match what's expected. If any field doesn't match, the -// test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { - v := vv.ToView() - if protocol != t.protocol { - t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) - } - - if srcAddr != t.srcAddr { - t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) - } - - if dstAddr != t.dstAddr { - t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) - } - - if len(v) != len(t.contents) { - t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) - } - - for i := range t.contents { - if t.contents[i] != v[i] { - t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) - } - } -} - -// DeliverTransportPacket is called by network endpoints after parsing incoming -// packets. This is used by the test object to verify that the results of the -// parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) - t.dataCalls++ -} - -// DeliverTransportControlPacket is called by network endpoints after parsing -// incoming control (ICMP) packets. This is used by the test object to verify -// that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) - if typ != t.typ { - t.t.Errorf("typ = %v, want %v", typ, t.typ) - } - if extra != t.extra { - t.t.Errorf("extra = %v, want %v", extra, t.extra) - } - t.controlCalls++ -} - -// Attach is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) Attach(stack.NetworkDispatcher) {} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (*testObject) IsAttached() bool { - return true -} - -// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that -// matches the linux loopback MTU. -func (*testObject) MTU() uint32 { - return 65536 -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (*testObject) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. -func (*testObject) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (*testObject) LinkAddress() tcpip.LinkAddress { - return "" -} - -// WritePacket is called by network endpoints after producing a packet and -// writing it to the link endpoint. This is used by the test object to verify -// that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - var prot tcpip.TransportProtocolNumber - var srcAddr tcpip.Address - var dstAddr tcpip.Address - - if t.v4 { - h := header.IPv4(hdr.View()) - prot = tcpip.TransportProtocolNumber(h.Protocol()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - - } else { - h := header.IPv6(hdr.View()) - prot = tcpip.TransportProtocolNumber(h.NextHeader()) - srcAddr = h.SourceAddress() - dstAddr = h.DestinationAddress() - } - t.checkValues(prot, payload, srcAddr, dstAddr) - return nil -} - -func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv4.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: ipv4SubnetAddr, - Mask: ipv4SubnetMask, - Gateway: ipv4Gateway, - NIC: 1, - }}) - - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) -} - -func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv6.ProtocolNumber, local) - s.SetRouteTable([]tcpip.Route{{ - Destination: ipv6SubnetAddr, - Mask: ipv6SubnetMask, - Gateway: ipv6Gateway, - NIC: 1, - }}) - - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) -} - -func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, nil, &o) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) - - // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - ep.HandlePacket(&r, view.ToVectorisedView()) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv4ReceiveControl(t *testing.T) { - const mtu = 0xbeef - header.IPv4MinimumSize - cases := []struct { - name string - expectedCount int - fragmentOffset uint16 - code uint8 - expectedTyp stack.ControlType - expectedExtra uint32 - trunc int - }{ - {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, - {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8}, - {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8}, - } - r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4 - view := buffer.NewView(dataOffset + 8) - - // Create the outer IPv4 header. - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(view) - c.trunc), - TTL: 20, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIpv4Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) - icmp.SetType(header.ICMPv4DstUnreachable) - icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) - - // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:]) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: 100, - TTL: 20, - Protocol: 10, - FragmentOffset: c.fragmentOffset, - SrcAddr: localIpv4Addr, - DstAddr: remoteIpv4Addr, - }) - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv4 endpoint, dispatcher will validate that - // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra - - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) - } - }) - } -} - -func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv4MinimumSize + 24 - - frag1 := buffer.NewView(totalLen) - ip1 := header.IPv4(frag1) - ip1.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 0, - Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag1[i] = uint8(i) - } - - frag2 := buffer.NewView(totalLen) - ip2 := header.IPv4(frag2) - ip2.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - FragmentOffset: 24, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, - }) - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - frag2[i] = uint8(i) - } - - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) - } - - // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv6Send(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, nil, &o) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Allocate and initialize the payload view. - payload := buffer.NewView(100) - for i := 0; i < len(payload); i++ { - payload[i] = uint8(i) - } - - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) - - // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload - - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } -} - -func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIpv6Addr, - DstAddr: localIpv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - ep.HandlePacket(&r, view.ToVectorisedView()) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) - } -} - -func TestIPv6ReceiveControl(t *testing.T) { - newUint16 := func(v uint16) *uint16 { return &v } - - const mtu = 0xffff - cases := []struct { - name string - expectedCount int - fragmentOffset *uint16 - typ header.ICMPv6Type - code uint8 - expectedTyp stack.ControlType - expectedExtra uint32 - trunc int - }{ - {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, - {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, - {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, - {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, - } - r, err := buildIPv6Route( - localIpv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - defer ep.Close() - - dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4 - if c.fragmentOffset != nil { - dataOffset += header.IPv6FragmentHeaderSize - } - view := buffer.NewView(dataOffset + 8) - - // Create the outer IPv6 header. - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - DstAddr: localIpv6Addr, - }) - - // Create the ICMP header. - icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) - icmp.SetType(c.typ) - icmp.SetCode(c.code) - copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) - - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIpv6Addr, - DstAddr: remoteIpv6Addr, - }) - - // Build the fragmentation header if needed. - if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, - FragmentOffset: *c.fragmentOffset, - M: true, - Identification: 0x12345678, - }) - } - - // Make payload be non-zero. - for i := dataOffset; i < len(view); i++ { - view[i] = uint8(i) - } - - // Give packet to IPv6 endpoint, dispatcher will validate that - // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra - - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD deleted file mode 100644 index be84fa63d..000000000 --- a/pkg/tcpip/network/ipv4/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv4", - srcs = [ - "icmp.go", - "ipv4.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/network/fragmentation", - "//pkg/tcpip/network/hash", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv4_test", - size = "small", - srcs = ["ipv4_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go new file mode 100755 index 000000000..6b2cc0142 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package ipv4 + diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go deleted file mode 100644 index 3207a3d46..000000000 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ /dev/null @@ -1,364 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv4_test - -import ( - "bytes" - "encoding/hex" - "math/rand" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestExcludeBroadcast(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) - - const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") - if testing.Verbose() { - id = sniffer.New(id) - } - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }}) - - randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} - - var wq waiter.Queue - t.Run("WithoutPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Cannot connect using a broadcast address as the source. - if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) - } - - // However, we can bind to a broadcast address to listen. - if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { - t.Errorf("Bind failed: %v", err) - } - }) - - t.Run("WithPrimaryAddress", func(t *testing.T) { - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - defer ep.Close() - - // Add a valid primary endpoint address, now we can connect. - if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := ep.Connect(randomAddr); err != nil { - t.Errorf("Connect failed: %v", err) - } - }) -} - -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s - } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload -} - -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) - - // Make a copy of the IP header, which will be modified in some fields to make - // an expected header. - sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) - sourceCopy.SetChecksum(0) - sourceCopy.SetFlagsFragmentOffset(0, 0) - sourceCopy.SetTotalLength(0) - var offset uint16 - // Build up an array of the bytes sent. - var reassembledPayload []byte - for i, packet := range packets { - // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) - } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) - } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) - } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) - } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) - } - if i < len(packets)-1 { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) - } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) - } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) - // Clear out the checksum and length from the ip because we can't compare - // it. - sourceCopy.SetTotalLength(uint16(len(ip))) - sourceCopy.SetChecksum(0) - sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) - } - } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) - } -} - -type errorChannel struct { - *channel.Endpoint - Ch chan packetInfo - packetCollectorErrors []*tcpip.Error -} - -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, - Ch: make(chan packetInfo, size), - packetCollectorErrors: packetCollectorErrors, - } - - return stack.RegisterLinkEndpoint(e), &ec -} - -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } -} - -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - - select { - case e.Ch <- p: - default: - } - - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - } - return nextError -} - -type context struct { - stack.Route - linkEP *errorChannel -} - -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { - // Make the packet and write it. - s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) - s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01") - s.SetRouteTable([]tcpip.Route{{ - Destination: "\x10\x00\x00\x02", - Mask: "\xff\xff\xff\xff", - Gateway: "", - NIC: 1, - }}) - r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - return context{ - Route: r, - linkEP: linkEP, - } -} - -func TestFragmentation(t *testing.T) { - var manyPayloadViewsSizes [1000]int - for i := range manyPayloadViewsSizes { - manyPayloadViewsSizes[i] = 7 - } - fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int - }{ - {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, - {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, - } - - for _, ft := range fragTests { - t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), - } - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42) - if err != nil { - t.Errorf("err got %v, want %v", err, nil) - } - - var results []packetInfo - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } - } - - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) - } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) - } - compareFragments(t, results, source, ft.mtu) - }) - } -} - -// TestFragmentationErrors checks that errors are returned from write packet -// correctly. -func TestFragmentationErrors(t *testing.T) { - fragTests := []struct { - description string - mtu uint32 - hdrLength int - payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error - }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, - } - - for _, ft := range fragTests { - t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } - } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) - } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD deleted file mode 100644 index fae7f4507..000000000 --- a/pkg/tcpip/network/ipv6/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ipv6", - srcs = [ - "icmp.go", - "ipv6.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - ], -) - -go_test( - name = "ipv6_test", - size = "small", - srcs = ["icmp_test.go"], - embed = [":ipv6"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/icmp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go deleted file mode 100644 index d46d68e73..000000000 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ipv6 - -import ( - "fmt" - "reflect" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") -) - -var ( - lladdr0 = header.LinkLocalAddr(linkAddr0) - lladdr1 = header.LinkLocalAddr(linkAddr1) -) - -type stubLinkEndpoint struct { - stack.LinkEndpoint -} - -func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 -} - -func (*stubLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { - return nil -} - -func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} - -type stubDispatcher struct { - stack.TransportDispatcher -} - -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { -} - -type stubLinkAddressCache struct { - stack.LinkAddressCache -} - -func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID { - return 0 -} - -func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { -} - -func TestICMPCounts(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) - { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), - NIC: 1, - }}, - ) - - ep, err := s.NetworkProtocolInstance(ProtocolNumber).NewEndpoint(0, lladdr1, &stubLinkAddressCache{}, &stubDispatcher{}, nil) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - defer r.Release() - - types := []struct { - typ header.ICMPv6Type - size int - }{ - {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize}, - {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize}, - {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize}, - {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize}, - {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, - {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, - {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, - {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, - } - - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: r.DefaultTTL(), - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) - } - - for _, typ := range types { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - handleIPv6Payload(hdr) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) - - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } -} - -func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { - t := v.Type() - for i := 0; i < v.NumField(); i++ { - v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter)) - case reflect.Struct: - visitStats(v, f) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) - } - } -} - -type testContext struct { - s0 *stack.Stack - s1 *stack.Stack - - linkEP0 *channel.Endpoint - linkEP1 *channel.Endpoint -} - -type endpointWithResolutionCapability struct { - stack.LinkEndpoint -} - -func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities { - return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired -} - -func newTestContext(t *testing.T) *testContext { - c := &testContext{ - s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - } - - const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) - if testing.Verbose() { - id0 = sniffer.New(id0) - } - if err := c.s0.CreateNIC(1, id0); err != nil { - t.Fatalf("CreateNIC s0: %v", err) - } - if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress lladdr0: %v", err) - } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } - - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { - t.Fatalf("AddAddress lladdr1: %v", err) - } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } - - c.s0.SetRouteTable( - []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), - NIC: 1, - }}, - ) - c.s1.SetRouteTable( - []tcpip.Route{{ - Destination: lladdr0, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), - NIC: 1, - }}, - ) - - return c -} - -func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) -} - -type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type -} - -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { - t.Helper() - - pkt := <-args.src.C - - { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) - } - - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) - return - } - ipv6 := header.IPv6(pkt.Header) - transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) - if transProto != header.ICMPv6ProtocolNumber { - t.Errorf("unexpected transport protocol number %d", transProto) - return - } - icmpv6 := header.ICMPv6(ipv6.Payload()) - if got, want := icmpv6.Type(), args.typ; got != want { - t.Errorf("got ICMPv6 type = %d, want = %d", got, want) - return - } - if fn != nil { - fn(t, icmpv6) - } -} - -func TestLinkResolution(t *testing.T) { - c := newTestContext(t) - defer c.cleanup() - - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - defer r.Release() - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - payload := tcpip.SlicePayload(hdr.View()) - - // We can't send our payload directly over the route because that - // doesn't provoke NDP discovery. - var wq waiter.Queue - ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } - - for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) - if resCh != nil { - if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) - } - <-resCh - continue - } - if err != nil { - t.Fatalf("ep.Write(_) = _, _, %s", err) - } - break - } - - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, - } { - routeICMPv6Packet(t, args, nil) - } -} diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go new file mode 100755 index 000000000..53319e0c4 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package ipv6 + diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD deleted file mode 100644 index 989058413..000000000 --- a/pkg/tcpip/ports/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "ports", - srcs = ["ports.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", - visibility = ["//:sandbox"], - deps = [ - "//pkg/tcpip", - ], -) - -go_test( - name = "ports_test", - srcs = ["ports_test.go"], - embed = [":ports"], - deps = [ - "//pkg/tcpip", - ], -) diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go new file mode 100755 index 000000000..664cc3e71 --- /dev/null +++ b/pkg/tcpip/ports/ports_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package ports + diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go deleted file mode 100644 index 689401661..000000000 --- a/pkg/tcpip/ports/ports_test.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ports - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 - - fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") - fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") -) - -type portReserveTestAction struct { - port uint16 - ip tcpip.Address - want *tcpip.Error - reuse bool - release bool -} - -func TestPortReservation(t *testing.T) { - for _, test := range []struct { - tname string - actions []portReserveTestAction - }{ - { - tname: "bind to ip", - actions: []portReserveTestAction{ - {port: 80, ip: fakeIPAddress, want: nil}, - {port: 80, ip: fakeIPAddress1, want: nil}, - /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, - }, - }, - { - tname: "bind to inaddr any", - actions: []portReserveTestAction{ - {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - /* release fakeIPAddress, but anyIPAddress is still inuse */ - {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, - /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ - {port: 22, ip: anyIPAddress, want: nil, release: true}, - {port: 22, ip: fakeIPAddress, want: nil}, - }, - }, { - tname: "bind to zero port", - actions: []portReserveTestAction{ - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, - }, - }, { - tname: "bind to ip with reuseport", - actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - - {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - - {port: 25, ip: anyIPAddress, reuse: true, want: nil}, - }, - }, { - tname: "bind to inaddr any with reuseport", - actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, release: true, want: nil}, - - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: nil}, - }, - }, - } { - t.Run(test.tname, func(t *testing.T) { - pm := NewPortManager() - net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} - - for _, test := range test.actions { - if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) - continue - } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) - if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) - } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) - } - } - }) - - } -} - -func TestPickEphemeralPort(t *testing.T) { - pm := NewPortManager() - customErr := &tcpip.Error{} - for _, test := range []struct { - name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error - wantPort uint16 - }{ - { - name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - { - name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr - }, - wantErr: customErr, - }, - { - name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port == FirstEphemeral+42 { - return true, nil - } - return false, nil - }, - wantPort: FirstEphemeral + 42, - }, - { - name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { - if port < FirstEphemeral { - return true, nil - } - return false, nil - }, - wantErr: tcpip.ErrNoPortAvailable, - }, - } { - t.Run(test.name, func(t *testing.T) { - if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) - } - }) - } -} diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD deleted file mode 100644 index 996939581..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_connect", - srcs = ["main.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go deleted file mode 100644 index 3ac381631..000000000 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and connects to a peer. Similar to "nc <address> <port>". While the -// sample is running, attempts to connect to its IPv4 address will result in -// a RST segment. -// -// As an example of how to run it, a TUN device can be created and enabled on -// a linux host as follows (this only needs to be done once per boot): -// -// [sudo] ip tuntap add user <username> mode tun <device-name> -// [sudo] ip link set <device-name> up -// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name> -// -// A concrete example: -// -// $ sudo ip tuntap add user wedsonaf mode tun tun0 -// $ sudo ip link set tun0 up -// $ sudo ip addr add 192.168.1.1/24 dev tun0 -// -// Then one can run tun_tcp_connect as such: -// -// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234 -// -// This will attempt to connect to the linux host's stack. One can run nc in -// listen mode to accept a connect from tun_tcp_connect and exchange data. -package main - -import ( - "bufio" - "fmt" - "log" - "math/rand" - "net" - "os" - "strconv" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// writer reads from standard input and writes to the endpoint until standard -// input is closed. It signals that it's done by closing the provided channel. -func writer(ch chan struct{}, ep tcpip.Endpoint) { - defer func() { - ep.Shutdown(tcpip.ShutdownWrite) - close(ch) - }() - - r := bufio.NewReader(os.Stdin) - for { - v := buffer.NewView(1024) - n, err := r.Read(v) - if err != nil { - return - } - - v.CapLength(n) - for len(v) > 0 { - n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - if err != nil { - fmt.Println("Write failed:", err) - return - } - - v.TrimFront(int(n)) - } - } -} - -func main() { - if len(os.Args) != 6 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>") - } - - tunName := os.Args[1] - addrName := os.Args[2] - portName := os.Args[3] - remoteAddrName := os.Args[4] - remotePortName := os.Args[5] - - rand.Seed(time.Now().UnixNano()) - - addr := tcpip.Address(net.ParseIP(addrName).To4()) - remote := tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()), - } - - var localPort uint16 - if v, err := strconv.Atoi(portName); err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } else { - localPort = uint16(v) - } - - if v, err := strconv.Atoi(remotePortName); err != nil { - log.Fatalf("Unable to convert port %v: %v", remotePortName, err) - } else { - remote.Port = uint16(v) - } - - // Create the stack with ipv4 and tcp protocols, then add a tun-based - // NIC and ipv4 address. - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - fd, err := tun.Open(tunName) - if err != nil { - log.Fatal(err) - } - - linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }, - }) - - // Create TCP endpoint. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - log.Fatal(e) - } - - // Bind if a port is specified. - if localPort != 0 { - if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil { - log.Fatal("Bind failed: ", err) - } - } - - // Issue connect request and wait for it to complete. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - terr := ep.Connect(remote) - if terr == tcpip.ErrConnectStarted { - fmt.Println("Connect is pending...") - <-notifyCh - terr = ep.GetSockOpt(tcpip.ErrorOption{}) - } - wq.EventUnregister(&waitEntry) - - if terr != nil { - log.Fatal("Unable to connect: ", terr) - } - - fmt.Println("Connected") - - // Start the writer in its own goroutine. - writerCompletedCh := make(chan struct{}) - go writer(writerCompletedCh, ep) // S/R-SAFE: sample code. - - // Read data and write to standard output until the peer closes the - // connection from its side. - wq.EventRegister(&waitEntry, waiter.EventIn) - for { - v, _, err := ep.Read(nil) - if err != nil { - if err == tcpip.ErrClosedForReceive { - break - } - - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - log.Fatal("Read() failed:", err) - } - - os.Stdout.Write(v) - } - wq.EventUnregister(&waitEntry) - - // The reader has completed. Now wait for the writer as well. - <-writerCompletedCh - - ep.Close() -} diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD deleted file mode 100644 index dad8ef399..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "tun_tcp_echo", - srcs = ["main.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/link/fdbased", - "//pkg/tcpip/link/rawfile", - "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go deleted file mode 100644 index da425394a..000000000 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// This sample creates a stack with TCP and IPv4 protocols on top of a TUN -// device, and listens on a port. Data received by the server in the accepted -// connections is echoed back to the clients. -package main - -import ( - "flag" - "log" - "math/rand" - "net" - "os" - "strconv" - "strings" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" - "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var tap = flag.Bool("tap", false, "use tap istead of tun") -var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") - -func echo(wq *waiter.Queue, ep tcpip.Endpoint) { - defer ep.Close() - - // Create wait queue entry that notifies a channel. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - for { - v, _, err := ep.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - return - } - - ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - } -} - -func main() { - flag.Parse() - if len(flag.Args()) != 3 { - log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>") - } - - tunName := flag.Arg(0) - addrName := flag.Arg(1) - portName := flag.Arg(2) - - rand.Seed(time.Now().UnixNano()) - - // Parse the mac address. - maddr, err := net.ParseMAC(*mac) - if err != nil { - log.Fatalf("Bad MAC address: %v", *mac) - } - - // Parse the IP address. Support both ipv4 and ipv6. - parsedAddr := net.ParseIP(addrName) - if parsedAddr == nil { - log.Fatalf("Bad IP address: %v", addrName) - } - - var addr tcpip.Address - var proto tcpip.NetworkProtocolNumber - if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) - proto = ipv4.ProtocolNumber - } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) - proto = ipv6.ProtocolNumber - } else { - log.Fatalf("Unknown IP type: %v", addrName) - } - - localPort, err := strconv.Atoi(portName) - if err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } - - // Create the stack with ip and tcp protocols, then add a tun-based - // NIC and address. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) - - mtu, err := rawfile.GetMTU(tunName) - if err != nil { - log.Fatal(err) - } - - var fd int - if *tap { - fd, err = tun.OpenTAP(tunName) - } else { - fd, err = tun.Open(tunName) - } - if err != nil { - log.Fatal(err) - } - - linkID, err := fdbased.New(&fdbased.Options{ - FDs: []int{fd}, - MTU: mtu, - EthernetHeader: *tap, - Address: tcpip.LinkAddress(maddr), - }) - if err != nil { - log.Fatal(err) - } - if err := s.CreateNIC(1, linkID); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) - } - - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - log.Fatal(err) - } - - // Add default route. - s.SetRouteTable([]tcpip.Route{ - { - Destination: tcpip.Address(strings.Repeat("\x00", len(addr))), - Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))), - Gateway: "", - NIC: 1, - }, - }) - - // Create TCP endpoint, bind it, then start listening. - var wq waiter.Queue - ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if err != nil { - log.Fatal(e) - } - - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil { - log.Fatal("Bind failed: ", err) - } - - if err := ep.Listen(10); err != nil { - log.Fatal("Listen failed: ", err) - } - - // Wait for connections to appear. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventIn) - defer wq.EventUnregister(&waitEntry) - - for { - n, wq, err := ep.Accept() - if err != nil { - if err == tcpip.ErrWouldBlock { - <-notifyCh - continue - } - - log.Fatal("Accept() failed:", err) - } - - go echo(wq, n) // S/R-SAFE: sample code. - } -} diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD deleted file mode 100644 index 76b5f4ffa..000000000 --- a/pkg/tcpip/seqnum/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "seqnum", - srcs = ["seqnum.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", - visibility = [ - "//visibility:public", - ], -) diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go new file mode 100755 index 000000000..bf76f6ac4 --- /dev/null +++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package seqnum + diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD deleted file mode 100644 index 28d11c797..000000000 --- a/pkg/tcpip/stack/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "stack", - srcs = [ - "linkaddrcache.go", - "nic.go", - "registration.go", - "route.go", - "stack.go", - "stack_global_state.go", - "transport_demuxer.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/ilist", - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/waiter", - ], -) - -go_test( - name = "stack_x_test", - size = "small", - srcs = [ - "stack_test.go", - "transport_test.go", - ], - deps = [ - ":stack", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/waiter", - ], -) - -go_test( - name = "stack_test", - size = "small", - srcs = ["linkaddrcache_test.go"], - embed = [":stack"], - deps = [ - "//pkg/sleep", - "//pkg/tcpip", - ], -) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go deleted file mode 100644 index 924f4d240..000000000 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "sync" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type testaddr struct { - addr tcpip.FullAddress - linkAddr tcpip.LinkAddress -} - -var testaddrs []testaddr - -type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration -} - -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - go func() { - if r.delay > 0 { - time.Sleep(r.delay) - } - r.fakeRequest(addr) - }() - return nil -} - -func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testaddrs { - if ta.addr.Addr == addr { - r.cache.add(ta.addr, ta.linkAddr) - break - } - } -} - -func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "broadcast" { - return "mac_broadcast", true - } - return "", false -} - -func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 1 -} - -func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - - for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err - } - s.Fetch(true) - } -} - -func init() { - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - testaddrs = append(testaddrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } -} - -func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testaddrs) - 1; i >= 0; i-- { - e := testaddrs[i] - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) - } - } - // Expect to find at least half of the most recent entries. - for i := 0; i < linkAddrCacheSize/2; i++ { - e := testaddrs[i] - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) - } - } - // The earliest entries should no longer be in the cache. - for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { - e := testaddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) - } - } -} - -func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - var wg sync.WaitGroup - for r := 0; r < 16; r++ { - wg.Add(1) - go func() { - for _, e := range testaddrs { - c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan - } - wg.Done() - }() - } - wg.Wait() - - // All goroutines add in the same order and add more values than - // can fit in the cache, so our eviction strategy requires that - // the last entry be present and the first be missing. - e := testaddrs[len(testaddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - e = testaddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testaddrs[0] - c.add(e.addr, e.linkAddr) - time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testaddrs[0] - l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - c.add(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != l2 { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) - } -} - -func TestCacheResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) - linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testaddrs { - got, err := getBlocking(c, ta.addr, linkRes) - if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) - } - if got != ta.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) - } - } - - // Check that after resolved, address stays in the cache and never returns WouldBlock. - for i := 0; i < 10; i++ { - e := testaddrs[len(testaddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - } -} - -func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} - - // First, sanity check that resolution is working... - e := testaddrs[0] - got, err := getBlocking(c, e.addr, linkRes) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) - } - if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) - } - - e.addr.Addr += "2" - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -func TestCacheResolutionTimeout(t *testing.T) { - resolverDelay := 500 * time.Millisecond - expiration := resolverDelay / 10 - c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - - e := testaddrs[0] - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) - } -} - -// TestStaticResolution checks that static link addresses are resolved immediately and don't -// send resolution requests. -func TestStaticResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, time.Millisecond, 1) - linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute} - - addr := tcpip.Address("broadcast") - want := tcpip.LinkAddress("mac_broadcast") - got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) - if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err) - } - if got != want { - t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) - } -} diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go new file mode 100755 index 000000000..2a15cfa0a --- /dev/null +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -0,0 +1,59 @@ +// automatically generated by stateify. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *TransportEndpointID) beforeSave() {} +func (x *TransportEndpointID) save(m state.Map) { + x.beforeSave() + m.Save("LocalPort", &x.LocalPort) + m.Save("LocalAddress", &x.LocalAddress) + m.Save("RemotePort", &x.RemotePort) + m.Save("RemoteAddress", &x.RemoteAddress) +} + +func (x *TransportEndpointID) afterLoad() {} +func (x *TransportEndpointID) load(m state.Map) { + m.Load("LocalPort", &x.LocalPort) + m.Load("LocalAddress", &x.LocalAddress) + m.Load("RemotePort", &x.RemotePort) + m.Load("RemoteAddress", &x.RemoteAddress) +} + +func (x *GSOType) save(m state.Map) { + m.SaveValue("", (int)(*x)) +} + +func (x *GSOType) load(m state.Map) { + m.LoadValue("", new(int), func(y interface{}) { *x = (GSOType)(y.(int)) }) +} + +func (x *GSO) beforeSave() {} +func (x *GSO) save(m state.Map) { + x.beforeSave() + m.Save("Type", &x.Type) + m.Save("NeedsCsum", &x.NeedsCsum) + m.Save("CsumOffset", &x.CsumOffset) + m.Save("MSS", &x.MSS) + m.Save("L3HdrLen", &x.L3HdrLen) + m.Save("MaxSize", &x.MaxSize) +} + +func (x *GSO) afterLoad() {} +func (x *GSO) load(m state.Map) { + m.Load("Type", &x.Type) + m.Load("NeedsCsum", &x.NeedsCsum) + m.Load("CsumOffset", &x.CsumOffset) + m.Load("MSS", &x.MSS) + m.Load("L3HdrLen", &x.L3HdrLen) + m.Load("MaxSize", &x.MaxSize) +} + +func init() { + state.Register("stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load}) + state.Register("stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load}) + state.Register("stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load}) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go deleted file mode 100644 index 69884af03..000000000 --- a/pkg/tcpip/stack/stack_test.go +++ /dev/null @@ -1,1147 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package stack_test contains tests for the stack. It is in its own package so -// that the tests can also validate that all definitions needed to implement -// transport and network protocols are properly exported by the stack package. -package stack_test - -import ( - "bytes" - "fmt" - "math" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 - - // fakeControlProtocol is used for control packets that represent - // destination port unreachable. - fakeControlProtocol tcpip.TransportProtocolNumber = 2 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fakeNetworkEndpoint struct { - nicid tcpip.NICID - id stack.NetworkEndpointID - proto *fakeNetworkProtocol - dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint -} - -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) -} - -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid -} - -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { - return &f.id -} - -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - - // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) - - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() - if len(nb) < fakeNetHeaderLen { - return - } - - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) - return - } - - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) -} - -func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen -} - -func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() -} - -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { - // Increment the sent packet count in the protocol descriptor. - f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ - - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := hdr.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(protocol) - - if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) - } - if loop&stack.PacketOut == 0 { - return nil - } - - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) -} - -func (*fakeNetworkEndpoint) Close() {} - -type fakeNetGoodOption bool - -type fakeNetBadOption bool - -type fakeNetInvalidValueOption int - -type fakeNetOptions struct { - good bool -} - -// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the -// number of packets sent and received via endpoints of this protocol. The index -// where packets are added is given by the packet's destination address MOD 10. -type fakeNetworkProtocol struct { - packetCount [10]int - sendPacketCount [10]int - opts fakeNetOptions -} - -func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fakeNetNumber -} - -func (f *fakeNetworkProtocol) MinimumPacketSize() int { - return fakeNetHeaderLen -} - -func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) -} - -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { - return &fakeNetworkEndpoint{ - nicid: nicid, - id: stack.NetworkEndpointID{addr}, - proto: f, - dispatcher: dispatcher, - linkEP: linkEP, - }, nil -} - -func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case fakeNetGoodOption: - f.opts.good = bool(v) - return nil - case fakeNetInvalidValueOption: - return tcpip.ErrInvalidOptionValue - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *fakeNetGoodOption: - *v = fakeNetGoodOption(f.opts.good) - return nil - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func TestNetworkReceive(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and two - // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Make sure packet with wrong address is not delivered. - buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to first endpoint. - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to second endpoint. - buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet that is too small is dropped. - buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } -} - -func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { - r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute failed: %v", err) - } - defer r.Release() - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Errorf("WritePacket failed: %v", err) - } -} - -func TestNetworkSend(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and one - // address: 1. The route table sends all packets through the only - // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("NewNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Make sure that the link-layer endpoint received the outbound packet. - sendTo(t, s, "\x03", nil) - if c := linkEP.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } -} - -func TestNetworkSendMultiRoute(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) - - // Send a packet to an odd destination. - sendTo(t, s, "\x05", nil) - - if c := linkEP1.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } - - // Send a packet to an even destination. - sendTo(t, s, "\x06", nil) - - if c := linkEP2.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } -} - -func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { - r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute failed: %v", err) - } - - defer r.Release() - - if r.LocalAddress != expectedSrcAddr { - t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress) - } - - if r.RemoteAddress != dstAddr { - t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress) - } -} - -func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { - _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err) - } -} - -func TestRoutes(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) - - // Test routes to odd address. - testRoute(t, s, 0, "", "\x05", "\x01") - testRoute(t, s, 0, "\x01", "\x05", "\x01") - testRoute(t, s, 1, "\x01", "\x05", "\x01") - testRoute(t, s, 0, "\x03", "\x05", "\x03") - testRoute(t, s, 1, "\x03", "\x05", "\x03") - - // Test routes to even address. - testRoute(t, s, 0, "", "\x06", "\x02") - testRoute(t, s, 0, "\x02", "\x06", "\x02") - testRoute(t, s, 2, "\x02", "\x06", "\x02") - testRoute(t, s, 0, "\x04", "\x06", "\x04") - testRoute(t, s, 2, "\x04", "\x06", "\x04") - - // Try to send to odd numbered address from even numbered ones, then - // vice-versa. - testNoRoute(t, s, 0, "\x02", "\x05") - testNoRoute(t, s, 2, "\x02", "\x05") - testNoRoute(t, s, 0, "\x04", "\x05") - testNoRoute(t, s, 2, "\x04", "\x05") - - testNoRoute(t, s, 0, "\x01", "\x06") - testNoRoute(t, s, 1, "\x01", "\x06") - testNoRoute(t, s, 0, "\x03", "\x06") - testNoRoute(t, s, 1, "\x03", "\x06") -} - -func TestAddressRemoval(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - - // Remove the address, then check that packet doesn't get delivered - // anymore. - if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) - } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - - // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress failed: %v", err) - } -} - -func TestDelayedRemovalDueToRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - - // Get a route, check that packet is still deliverable. - r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute failed: %v", err) - } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 2 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) - } - - // Remove the address, then check that packet is still deliverable - // because the route is keeping the address alive. - if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) - } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } - - // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress failed: %v", err) - } - - // Release the route, then check that packet is not deliverable anymore. - r.Release() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } -} - -func TestPromiscuousMode(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it doesn't get delivered as we don't - // have a matching endpoint. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } - - // Set promiscuous mode, then check that packet is delivered. - if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatalf("SetPromiscuousMode failed: %v", err) - } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - - // Check that we can't get a route as there is no local address. - _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err) - } - - // Set promiscuous mode to false, then check that packet can't be - // delivered anymore. - if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatalf("SetPromiscuousMode failed: %v", err) - } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } -} - -func TestAddressSpoofing(t *testing.T) { - srcAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") - - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatalf("SetSpoofing failed: %v", err) - } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute failed: %v", err) - } - if r.LocalAddress != srcAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) - } - if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) - } -} - -func TestBroadcastNeedsNoRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{}) - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) - } - - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, header.IPv4Any, err) - } - r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) - } - - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) - } - - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) - } -} - -func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { - for _, tc := range []struct { - name string - routeNeeded bool - address tcpip.Address - }{ - // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255 - // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff - {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"}, - {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"}, - {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"}, - {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"}, - {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"}, - - // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112] - {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 link-local address starts with fe80::/10. - {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"}, - {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 addresses that are neither multicast nor link-local. - {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - } { - t.Run(tc.name, func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{}) - - var anyAddr tcpip.Address - if len(tc.address) == header.IPv4AddressSize { - anyAddr = header.IPv4Any - } else { - anyAddr = header.IPv6Any - } - - want := tcpip.ErrNetworkUnreachable - if tc.routeNeeded { - want = tcpip.ErrNoRoute - } - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) - } - - if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { - // Route table is empty but we need a route, this should cause an error. - if err != tcpip.ErrNoRoute { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) - } - } else { - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err) - } - if r.LocalAddress != anyAddr { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, anyAddr) - } - if r.RemoteAddress != tc.address { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, tc.address) - } - } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - }) - } -} - -// Set the subnet, then check that packet is delivered. -func TestSubnetAcceptsMatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - buf[0] = 1 - fakeNet.packetCount[1] = 0 - subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) - } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } -} - -// Set destination outside the subnet, then check it doesn't get delivered. -func TestSubnetRejectsNonmatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - buf[0] = 1 - fakeNet.packetCount[1] = 0 - subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) - } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } -} - -func TestNetworkOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) - - // Try an unsupported network protocol. - if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.NetworkProtocol) - }{ - {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { - t.Helper() - fakeNet := p.(*fakeNetworkProtocol) - if fakeNet.opts.good != true { - t.Fatalf("fakeNet.opts.good = false, want = true") - } - var v fakeNetGoodOption - if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) - } - }}, - {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) - } - } -} - -func TestSubnetAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - addr := tcpip.Address("\x01\x01\x01\x01") - mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - subnet, err := tcpip.NewSubnet(addr, mask) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") - } - - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) - } - - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) - } else if !contained { - t.Fatal("got s.ContainsSubnet(...) = false, want = true") - } - - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatalf("RemoveSubnet failed: %v", err) - } - - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatalf("ContainsSubnet failed: %v", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") - } -} - -func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { - for _, addrLen := range []int{4, 16} { - t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { - for canBe := 0; canBe < 3; canBe++ { - t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { - for never := 0; never < 3; never++ { - t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - // Insert <canBe> primary and <never> never-primary addresses. - // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.Address]tcpip.Subnet) - for i := 0; i < canBe+never; i++ { - var behavior stack.PrimaryEndpointBehavior - if i < canBe { - behavior = stack.CanBePrimaryEndpoint - } else { - behavior = stack.NeverPrimaryEndpoint - } - // Add an address and in case of a primary one also add a - // subnet. - address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions failed: %v", err) - } - if behavior == stack.CanBePrimaryEndpoint { - mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) - subnet, err := tcpip.NewSubnet(address, mask) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) - } - // Remember the address/subnet. - primaryAddrAdded[address] = subnet - } - } - // Check that GetMainNICAddress returns an address if at least - // one primary address was added. In that case make sure the - // address/subnet matches what we added. - if len(primaryAddrAdded) == 0 { - // No primary addresses present, expect an error. - if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %v", err, tcpip.ErrNoLinkAddress) - } - } else { - // At least one primary address was added, expect a valid - // address and subnet. - gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatalf("GetMainNICAddress failed: %v", err) - } - expectedSubnet, ok := primaryAddrAdded[gotAddress] - if !ok { - t.Fatalf("GetMainNICAddress: got address = %v, wanted any in {%v}", gotAddress, primaryAddrAdded) - } - if gotSubnet != expectedSubnet { - t.Fatalf("GetMainNICAddress: got subnet = %v, wanted %v", gotSubnet, expectedSubnet) - } - } - }) - } - }) - } - }) - } -} - -func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - for _, tc := range []struct { - name string - address tcpip.Address - }{ - {"IPv4", "\x01\x01\x01\x01"}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - } { - t.Run(tc.name, func(t *testing.T) { - address := tc.address - mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) - subnet, err := tcpip.NewSubnet(address, mask) - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatalf("AddSubnet failed: %v", err) - } - - // Check that we get the right initial address and subnet. - if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress failed: %v", err) - } else if gotAddress != address { - t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address) - } else if gotSubnet != subnet { - t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet) - } - - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatalf("RemoveSubnet failed: %v", err) - } - - if err := s.RemoveAddress(1, address); err != nil { - t.Fatalf("RemoveAddress failed: %v", err) - } - - // Check that we get an error after removal. - if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress) - } - }) - } -} - -func TestNICStats(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - // Route all packets for address \x01 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\xff", "\x00", 1}, - }) - - // Send a packet to address 1. - buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) - } - - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - sendTo(t, s, "\x01", payload) - want := uint64(linkEP1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) - } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - s.SetForwarding(true) - - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - }) - - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) - - select { - case <-linkEP2.C: - default: - t.Fatal("Packet not forwarded") - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -func init() { - stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { - return &fakeNetworkProtocol{} - }) -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go deleted file mode 100644 index 10fc8f195..000000000 --- a/pkg/tcpip/stack/transport_test.go +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 -) - -// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't -// use it. -type fakeTransportEndpoint struct { - id stack.TransportEndpointID - stack *stack.Stack - netProto tcpip.NetworkProtocolNumber - proto *fakeTransportProtocol - peerAddr tcpip.Address - route stack.Route - - // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint -} - -func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto} -} - -func (f *fakeTransportEndpoint) Close() { - f.route.Release() -} - -func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return mask -} - -func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return buffer.View{}, tcpip.ControlMessages{}, nil -} - -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { - if len(f.route.RemoteAddress) == 0 { - return 0, nil, tcpip.ErrNoRoute - } - - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) - if err != nil { - return 0, nil, err - } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { - return 0, nil, err - } - - return uintptr(len(v)), nil, nil -} - -func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - -// SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - } - return tcpip.ErrInvalidEndpointState -} - -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - f.peerAddr = addr.Addr - - // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return tcpip.ErrNoRoute - } - defer r.Release() - - // Try to register so that we can start receiving packets. - f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) - if err != nil { - return err - } - - f.route = r.Clone() - - return nil -} - -func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Reset() { -} - -func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { - return nil -} - -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - if len(f.acceptQueue) == 0 { - return nil, nil, nil - } - a := f.acceptQueue[0] - f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil -} - -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( - a.NIC, - []tcpip.NetworkProtocolNumber{fakeNetNumber}, - fakeTransNumber, - stack.TransportEndpointID{LocalAddress: a.Addr}, - f, - false, - ); err != nil { - return err - } - f.acceptQueue = []fakeTransportEndpoint{} - return nil -} - -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { - // Increment the number of received packets. - f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - id: id, - stack: f.stack, - netProto: f.netProto, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) - } -} - -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { - // Increment the number of received control packets. - f.proto.controlCount++ -} - -func (f *fakeTransportEndpoint) State() uint32 { - return 0 -} - -type fakeTransportGoodOption bool - -type fakeTransportBadOption bool - -type fakeTransportInvalidValueOption int - -type fakeTransportProtocolOptions struct { - good bool -} - -// fakeTransportProtocol is a transport-layer protocol descriptor. It -// aggregates the number of packets received via endpoints of this protocol. -type fakeTransportProtocol struct { - packetCount int - controlCount int - opts fakeTransportProtocolOptions -} - -func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { - return fakeTransNumber -} - -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil -} - -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrUnknownProtocol -} - -func (*fakeTransportProtocol) MinimumPacketSize() int { - return fakeTransHeaderLen -} - -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { - return 0, 0, nil -} - -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { - return true -} - -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) - return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) - return nil - default: - return tcpip.ErrUnknownProtocolOption - } -} - -func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the packet. - buf := buffer.NewView(30) - - // Make sure packet with wrong protocol is not delivered. - buf[0] = 1 - buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[0] = 1 - buf[1] = 3 - buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet is delivered. - buf[0] = 1 - buf[1] = 2 - buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeTrans.packetCount != 1 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) - } -} - -func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the control packet. - buf := buffer.NewView(2*fakeNetHeaderLen + 30) - - // Outer packet contains the control protocol number. - buf[0] = 1 - buf[1] = 0xfe - buf[2] = uint8(fakeControlProtocol) - - // Make sure packet with wrong protocol is not delivered. - buf[fakeNetHeaderLen+0] = 0 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[fakeNetHeaderLen+0] = 3 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet is delivered. - buf[fakeNetHeaderLen+0] = 2 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeTrans.controlCount != 1 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) - } -} - -func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) - - // Create endpoint and bind it. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Create buffer that will hold the payload. - view := buffer.NewView(30) - _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("write failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - if fakeNet.sendPacketCount[2] != 1 { - t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) - } -} - -func TestTransportOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } - } -} - -func TestTransportForwarding(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - s.SetForwarding(true) - - // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - {"\x01", "\xff", "\x00", 1}, - }) - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) - - aep, _, err := ep.Accept() - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - var p channel.PacketInfo - select { - case p = <-linkEP2.C: - default: - t.Fatal("Response packet not forwarded") - } - - if dst := p.Header[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := p.Header[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} - -func init() { - stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { - return &fakeTransportProtocol{} - }) -} diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go new file mode 100755 index 000000000..36beffe67 --- /dev/null +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -0,0 +1,40 @@ +// automatically generated by stateify. + +package tcpip + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *FullAddress) beforeSave() {} +func (x *FullAddress) save(m state.Map) { + x.beforeSave() + m.Save("NIC", &x.NIC) + m.Save("Addr", &x.Addr) + m.Save("Port", &x.Port) +} + +func (x *FullAddress) afterLoad() {} +func (x *FullAddress) load(m state.Map) { + m.Load("NIC", &x.NIC) + m.Load("Addr", &x.Addr) + m.Load("Port", &x.Port) +} + +func (x *ControlMessages) beforeSave() {} +func (x *ControlMessages) save(m state.Map) { + x.beforeSave() + m.Save("HasTimestamp", &x.HasTimestamp) + m.Save("Timestamp", &x.Timestamp) +} + +func (x *ControlMessages) afterLoad() {} +func (x *ControlMessages) load(m state.Map) { + m.Load("HasTimestamp", &x.HasTimestamp) + m.Load("Timestamp", &x.Timestamp) +} + +func init() { + state.Register("tcpip.FullAddress", (*FullAddress)(nil), state.Fns{Save: (*FullAddress).save, Load: (*FullAddress).load}) + state.Register("tcpip.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load}) +} diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go deleted file mode 100644 index ebb1c1b56..000000000 --- a/pkg/tcpip/tcpip_test.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import ( - "fmt" - "net" - "strings" - "testing" -) - -func TestSubnetContains(t *testing.T) { - tests := []struct { - s Address - m AddressMask - a Address - want bool - }{ - {"\xa0", "\xf0", "\x90", false}, - {"\xa0", "\xf0", "\xa0", true}, - {"\xa0", "\xf0", "\xa5", true}, - {"\xa0", "\xf0", "\xaf", true}, - {"\xa0", "\xf0", "\xb0", false}, - {"\xa0", "\xf0", "", false}, - {"\xa0", "\xf0", "\xa0\x00", false}, - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - s, err := NewSubnet(tt.s, tt.m) - if err != nil { - t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err) - continue - } - if got := s.Contains(tt.a); got != tt.want { - t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want) - } - } -} - -func TestSubnetBits(t *testing.T) { - tests := []struct { - a AddressMask - want1 int - want0 int - }{ - {"\x00", 0, 8}, - {"\x00\x00", 0, 16}, - {"\x36", 4, 4}, - {"\x5c", 4, 4}, - {"\x5c\x5c", 8, 8}, - {"\x5c\x36", 8, 8}, - {"\x36\x5c", 8, 8}, - {"\x36\x36", 8, 8}, - {"\xff", 8, 0}, - {"\xff\xff", 16, 0}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got1, got0 := s.Bits() - if got1 != tt.want1 || got0 != tt.want0 { - t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0) - } - } -} - -func TestSubnetPrefix(t *testing.T) { - tests := []struct { - a AddressMask - want int - }{ - {"\x00", 0}, - {"\x00\x00", 0}, - {"\x36", 0}, - {"\x86", 1}, - {"\xc5", 2}, - {"\xff\x00", 8}, - {"\xff\x36", 8}, - {"\xff\x8c", 9}, - {"\xff\xc8", 10}, - {"\xff", 8}, - {"\xff\xff", 16}, - } - for _, tt := range tests { - s := &Subnet{mask: tt.a} - got := s.Prefix() - if got != tt.want { - t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want) - } - } -} - -func TestSubnetCreation(t *testing.T) { - tests := []struct { - a Address - m AddressMask - want error - }{ - {"\xa0", "\xf0", nil}, - {"\xa0\xa0", "\xf0", errSubnetLengthMismatch}, - {"\xaa", "\xf0", errSubnetAddressMasked}, - {"", "", nil}, - } - for _, tt := range tests { - if _, err := NewSubnet(tt.a, tt.m); err != tt.want { - t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want) - } - } -} - -func TestRouteMatch(t *testing.T) { - tests := []struct { - d Address - m AddressMask - a Address - want bool - }{ - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - r := Route{Destination: tt.d, Mask: tt.m} - if got := r.Match(tt.a); got != tt.want { - t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want) - } - } -} - -func TestAddressString(t *testing.T) { - for _, want := range []string{ - // Taken from stdlib. - "2001:db8::123:12:1", - "2001:db8::1", - "2001:db8:0:1:0:1:0:1", - "2001:db8:1:0:1:0:1:0", - "2001::1:0:0:1", - "2001:db8:0:0:1::", - "2001:db8::1:0:0:1", - "2001:db8::a:b:c:d", - - // Leading zeros. - "::1", - // Trailing zeros. - "8::", - // No zeros. - "1:1:1:1:1:1:1:1", - // Longer sequence is after other zeros, but not at the end. - "1:0:0:1::1", - // Longer sequence is at the beginning, shorter sequence is at - // the end. - "::1:1:1:0:0", - // Longer sequence is not at the beginning, shorter sequence is - // at the end. - "1::1:1:0:0", - // Longer sequence is at the beginning, shorter sequence is not - // at the end. - "::1:1:0:0:1", - // Neither sequence is at an end, longer is after shorter. - "1:0:0:1::1", - // Shorter sequence is at the beginning, longer sequence is not - // at the end. - "0:0:1:1::1", - // Shorter sequence is at the beginning, longer sequence is at - // the end. - "0:0:1:1:1::", - // Short sequences at both ends, longer one in the middle. - "0:1:1::1:1:0", - // Short sequences at both ends, longer one in the middle. - "0:1::1:0:0", - // Short sequences at both ends, longer one in the middle. - "0:0:1::1:0", - // Longer sequence surrounded by shorter sequences, but none at - // the end. - "1:0:1::1:0:1", - } { - addr := Address(net.ParseIP(want)) - if got := addr.String(); got != want { - t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want) - } - } -} - -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD deleted file mode 100644 index 62182a3e6..000000000 --- a/pkg/tcpip/transport/icmp/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -go_template_instance( - name = "icmp_packet_list", - out = "icmp_packet_list.go", - package = "icmp", - prefix = "icmpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*icmpPacket", - "Linker": "*icmpPacket", - }, -) - -go_library( - name = "icmp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "icmp_packet_list.go", - "protocol.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) - -filegroup( - name = "autogen", - srcs = [ - "icmp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go new file mode 100755 index 000000000..1b35e5b4a --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -0,0 +1,173 @@ +package icmp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type icmpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type icmpPacketList struct { + head *icmpPacket + tail *icmpPacket +} + +// Reset resets list l to the empty state. +func (l *icmpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *icmpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *icmpPacketList) Front() *icmpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *icmpPacketList) Back() *icmpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *icmpPacketList) PushFront(e *icmpPacket) { + icmpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *icmpPacketList) PushBack(e *icmpPacket) { + icmpPacketElementMapper{}.linkerFor(e).SetNext(nil) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *icmpPacketList) PushBackList(m *icmpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) { + a := icmpPacketElementMapper{}.linkerFor(b).Next() + icmpPacketElementMapper{}.linkerFor(e).SetNext(a) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(b) + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) { + b := icmpPacketElementMapper{}.linkerFor(a).Prev() + icmpPacketElementMapper{}.linkerFor(e).SetNext(a) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(b) + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *icmpPacketList) Remove(e *icmpPacket) { + prev := icmpPacketElementMapper{}.linkerFor(e).Prev() + next := icmpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + icmpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type icmpPacketEntry struct { + next *icmpPacket + prev *icmpPacket +} + +// Next returns the entry that follows e in the list. +func (e *icmpPacketEntry) Next() *icmpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *icmpPacketEntry) Prev() *icmpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *icmpPacketEntry) SetNext(elem *icmpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go new file mode 100755 index 000000000..5caa1cd64 --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -0,0 +1,98 @@ +// automatically generated by stateify. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *icmpPacket) beforeSave() {} +func (x *icmpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("icmpPacketEntry", &x.icmpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) +} + +func (x *icmpPacket) afterLoad() {} +func (x *icmpPacket) load(m state.Map) { + m.Load("icmpPacketEntry", &x.icmpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("netProto", &x.netProto) + m.Save("transProto", &x.transProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("id", &x.id) + m.Save("state", &x.state) + m.Save("bindNICID", &x.bindNICID) + m.Save("bindAddr", &x.bindAddr) + m.Save("regNICID", &x.regNICID) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.Load("transProto", &x.transProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("id", &x.id) + m.Load("state", &x.state) + m.Load("bindNICID", &x.bindNICID) + m.Load("bindAddr", &x.bindAddr) + m.Load("regNICID", &x.regNICID) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *icmpPacketList) beforeSave() {} +func (x *icmpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *icmpPacketList) afterLoad() {} +func (x *icmpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *icmpPacketEntry) beforeSave() {} +func (x *icmpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *icmpPacketEntry) afterLoad() {} +func (x *icmpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load}) + state.Register("icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load}) + state.Register("icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD deleted file mode 100644 index 34a14bf7f..000000000 --- a/pkg/tcpip/transport/raw/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") - -go_template_instance( - name = "packet_list", - out = "packet_list.go", - package = "raw", - prefix = "packet", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*packet", - "Linker": "*packet", - }, -) - -go_library( - name = "raw", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "packet_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) - -filegroup( - name = "autogen", - srcs = [ - "packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/raw/packet_list.go b/pkg/tcpip/transport/raw/packet_list.go new file mode 100755 index 000000000..2e9074934 --- /dev/null +++ b/pkg/tcpip/transport/raw/packet_list.go @@ -0,0 +1,173 @@ +package raw + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type packetElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (packetElementMapper) linkerFor(elem *packet) *packet { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type packetList struct { + head *packet + tail *packet +} + +// Reset resets list l to the empty state. +func (l *packetList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *packetList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *packetList) Front() *packet { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *packetList) Back() *packet { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *packetList) PushFront(e *packet) { + packetElementMapper{}.linkerFor(e).SetNext(l.head) + packetElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + packetElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *packetList) PushBack(e *packet) { + packetElementMapper{}.linkerFor(e).SetNext(nil) + packetElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *packetList) PushBackList(m *packetList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(m.head) + packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *packetList) InsertAfter(b, e *packet) { + a := packetElementMapper{}.linkerFor(b).Next() + packetElementMapper{}.linkerFor(e).SetNext(a) + packetElementMapper{}.linkerFor(e).SetPrev(b) + packetElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + packetElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *packetList) InsertBefore(a, e *packet) { + b := packetElementMapper{}.linkerFor(a).Prev() + packetElementMapper{}.linkerFor(e).SetNext(a) + packetElementMapper{}.linkerFor(e).SetPrev(b) + packetElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + packetElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *packetList) Remove(e *packet) { + prev := packetElementMapper{}.linkerFor(e).Prev() + next := packetElementMapper{}.linkerFor(e).Next() + + if prev != nil { + packetElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + packetElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type packetEntry struct { + next *packet + prev *packet +} + +// Next returns the entry that follows e in the list. +func (e *packetEntry) Next() *packet { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *packetEntry) Prev() *packet { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *packetEntry) SetNext(elem *packet) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *packetEntry) SetPrev(elem *packet) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go new file mode 100755 index 000000000..bdbf6207a --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -0,0 +1,96 @@ +// automatically generated by stateify. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *packet) beforeSave() {} +func (x *packet) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("packetEntry", &x.packetEntry) + m.Save("timestampNS", &x.timestampNS) + m.Save("senderAddr", &x.senderAddr) +} + +func (x *packet) afterLoad() {} +func (x *packet) load(m state.Map) { + m.Load("packetEntry", &x.packetEntry) + m.Load("timestampNS", &x.timestampNS) + m.Load("senderAddr", &x.senderAddr) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("netProto", &x.netProto) + m.Save("transProto", &x.transProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("closed", &x.closed) + m.Save("connected", &x.connected) + m.Save("bound", &x.bound) + m.Save("registeredNIC", &x.registeredNIC) + m.Save("boundNIC", &x.boundNIC) + m.Save("boundAddr", &x.boundAddr) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.Load("transProto", &x.transProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("closed", &x.closed) + m.Load("connected", &x.connected) + m.Load("bound", &x.bound) + m.Load("registeredNIC", &x.registeredNIC) + m.Load("boundNIC", &x.boundNIC) + m.Load("boundAddr", &x.boundAddr) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *packetList) beforeSave() {} +func (x *packetList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *packetList) afterLoad() {} +func (x *packetList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *packetEntry) beforeSave() {} +func (x *packetEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *packetEntry) afterLoad() {} +func (x *packetEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("raw.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load}) + state.Register("raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("raw.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load}) + state.Register("raw.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD deleted file mode 100644 index 4cd25e8e2..000000000 --- a/pkg/tcpip/transport/tcp/BUILD +++ /dev/null @@ -1,97 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "tcp_segment_list", - out = "tcp_segment_list.go", - package = "tcp", - prefix = "segment", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*segment", - "Linker": "*segment", - }, -) - -go_library( - name = "tcp", - srcs = [ - "accept.go", - "connect.go", - "cubic.go", - "cubic_state.go", - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "rcv.go", - "reno.go", - "sack.go", - "sack_scoreboard.go", - "segment.go", - "segment_heap.go", - "segment_queue.go", - "segment_state.go", - "snd.go", - "snd_state.go", - "tcp_segment_list.go", - "timer.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/rand", - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tmutex", - "//pkg/waiter", - "@com_github_google_btree//:go_default_library", - ], -) - -filegroup( - name = "autogen", - srcs = [ - "tcp_segment_list.go", - ], - visibility = ["//:sandbox"], -) - -go_test( - name = "tcp_test", - size = "small", - srcs = [ - "dual_stack_test.go", - "sack_scoreboard_test.go", - "tcp_noracedetector_test.go", - "tcp_sack_test.go", - "tcp_test.go", - "tcp_timestamp_test.go", - ], - # FIXME(b/68809571) - tags = ["flaky"], - deps = [ - ":tcp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp/testing/context", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go deleted file mode 100644 index d9f79e8c5..000000000 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ /dev/null @@ -1,572 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestV4MappedConnectOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Start connection attempt, it must fail. - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrNoRoute { - t.Fatalf("Unexpected return value from Connect: %v", err) - } -} - -func testV4Connect(t *testing.T, c *context.Context) { - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV4MappedConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func testV6Connect(t *testing.T, c *context.Context) { - // Start connection attempt to IPv6 address. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetV6Packet() - checker.IPv6(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - - tcp := header.TCP(header.IPv6(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV6Connect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to local address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV4RefuseOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), - ), - ) -} - -func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), - ), - ) -} - -func testV4Accept(t *testing.T, c *context.Context) { - c.SetGSOEnabled(true) - defer c.SetGSOEnabled(false) - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv4(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Make sure we get the same error when calling the original ep and the - // new one. This validates that v4-mapped endpoints are still able to - // query the V6Only flag, whereas pure v4 endpoints are not. - var v tcpip.V6OnlyOption - expected := c.EP.GetSockOpt(&v) - if err := nep.GetSockOpt(&v); err != expected { - t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestAddr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) - } - - data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestV4AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV6AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv6(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Make sure we can still query the v6 only status of the new endpoint, - // that is, that it is in fact a v6 socket. - var v tcpip.V6OnlyOption - if err := nep.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) - } -} - -func TestV4AcceptOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go deleted file mode 100644 index b4e5ba0df..000000000 --- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -const smss = 1500 - -func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard { - s := tcp.NewSACKScoreboard(smss, iss) - for _, blk := range blocks { - s.Insert(blk) - } - return s -} - -func TestSACKScoreboardIsSACKED(t *testing.T) { - type blockTest struct { - block header.SACKBlock - sacked bool - } - testCases := []struct { - comment string - scoreboardBlocks []header.SACKBlock - blockTests []blockTest - iss seqnum.Value - }{ - { - "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", - []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, - []blockTest{ - {header.SACKBlock{15, 21}, true}, - {header.SACKBlock{200, 201}, false}, - {header.SACKBlock{50, 51}, false}, - {header.SACKBlock{53, 120}, true}, - }, - 0, - }, - { - "Test disjoint SACKBlocks", - []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, - []blockTest{ - {header.SACKBlock{2288624809, 2288810057}, true}, - {header.SACKBlock{2288811477, 2288838565}, true}, - {header.SACKBlock{2288810057, 2288811477}, false}, - }, - 2288624809, - }, - { - "Test sequence number wrap around", - []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, - []blockTest{ - {header.SACKBlock{4294254144, 4294254145}, true}, - {header.SACKBlock{4294254143, 4294254144}, false}, - {header.SACKBlock{4294254144, 1}, true}, - {header.SACKBlock{225652, 5350509}, false}, - {header.SACKBlock{5340409, 5350509}, true}, - {header.SACKBlock{5350509, 5350609}, false}, - }, - 4294254144, - }, - { - "Test disjoint SACKBlocks out of order", - []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, - []blockTest{ - {header.SACKBlock{827426028, 827428867}, true}, - {header.SACKBlock{827450168, 827450275}, false}, - }, - 827426000, - }, - } - for _, tc := range testCases { - sb := initScoreboard(tc.scoreboardBlocks, tc.iss) - for _, blkTest := range tc.blockTests { - if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { - t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) - } - } - } -} - -func TestSACKScoreboardIsRangeLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - block header.SACKBlock - lost bool - }{ - // Block not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number covered by this block. - {block: header.SACKBlock{0, 1}, lost: true}, - - // These blocks have all been SACKed and should not be - // considered lost. - {block: header.SACKBlock{1, 2}, lost: false}, - {block: header.SACKBlock{25, 26}, lost: false}, - {block: header.SACKBlock{1, 45}, lost: false}, - - // Same as the first case above. - {block: header.SACKBlock{50, 51}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{119, 120}, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {block: header.SACKBlock{120, 121}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{125, 126}, lost: false}, - - // This block has not been SACKed and there are nDupAckThreshold - // number of SACKed blocks after it. - {block: header.SACKBlock{141, 145}, lost: true}, - - // This block has not been SACKed and there are less than - // nDupAckThreshold SACKed sequences after it. - {block: header.SACKBlock{151, 152}, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { - t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) - } - } -} - -func TestSACKScoreboardIsLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - seq seqnum.Value - lost bool - }{ - // Sequence number not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number. - {seq: 0, lost: true}, - - // These sequence numbers have all been SACKed and should not be - // considered lost. - {seq: 1, lost: false}, - {seq: 25, lost: false}, - {seq: 45, lost: false}, - - // Same as first case above. - {seq: 50, lost: true}, - - // This block has been SACKed and should not be considered lost. - {seq: 119, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {seq: 120, lost: true}, - - // This sequence number has been SACKed and should not be - // considered lost. - {seq: 125, lost: false}, - - // This sequence number has not been SACKed and there are - // nDupAckThreshold number of SACKed blocks after it. - {seq: 141, lost: true}, - - // This sequence number has not been SACKed and there are less - // than nDupAckThreshold SACKed sequences after it. - {seq: 151, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.seq); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) - } - } -} - -func TestSACKScoreboardDelete(t *testing.T) { - blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} - s := initScoreboard(blocks, 4294254143) - s.Delete(5340408) - if s.Empty() { - t.Fatalf("s.Empty() = true, want false") - } - if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } - s.Delete(5340410) - if s.Empty() { - t.Fatal("s.Empty() = true, want false") - } - newSB := header.SACKBlock{5340410, 5350509} - if !s.IsSACKED(newSB) { - t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) - } - s.Delete(5350509) - lastOctet := header.SACKBlock{5350508, 5350509} - if s.IsSACKED(lastOctet) { - t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) - } - - s.Delete(5350510) - if !s.Empty() { - t.Fatal("s.Empty() = false, want true") - } - if got, want := s.Sacked(), seqnum.Size(0); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go deleted file mode 100644 index 272bbcdbd..000000000 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ /dev/null @@ -1,519 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// These tests are flaky when run under the go race detector due to some -// iterations taking long enough that the retransmit timer can kick in causing -// the congestion window measurements to fail due to extra packets etc. -// -// +build !race - -package tcp_test - -import ( - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" -) - -func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) - } - - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 - } - expected++ - } -} - -func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // Double the number of expected packets for the next iteration. - expected *= 2 - } -} - -func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } -} - -// cubicCwnd returns an estimate of a cubic window given the -// originalCwnd, wMax, last congestion event time and sRTT. -func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { - cwnd := float64(origCwnd) - // We wait 50ms between each iteration so sRTT as computed by cubic - // should be close to 50ms. - elapsed := (time.Since(congEventTime) + sRTT).Seconds() - k := math.Cbrt(float64(wMax) * 0.3 / 0.7) - wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) - cwnd += (wtRTT - cwnd) / cwnd - return int(cwnd) -} - -func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - enableCUBIC(t, c) - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. - // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) - - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break - } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - } -} - -func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - const iterations = 7 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - half := data[:len(data)/2] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - half = data[len(data)/2:] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) - } - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. - for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } - - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) -} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go deleted file mode 100644 index 4e7f1a740..000000000 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ /dev/null @@ -1,564 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "fmt" - "log" - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" -) - -// createConnectedWithSACKPermittedOption creates and connects c.ep with the -// SACKPermitted option enabled if the stack in the context has the SACK support -// enabled. -func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) -} - -// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS -// option enabled if the stack in the context has SACK and TS enabled. -func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) -} - -func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { - t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) - } -} - -// TestSackPermittedConnect establishes a connection with the SACK option -// enabled. -func TestSackPermittedConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - rep := createConnectedWithSACKPermittedOption(c) - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // Restore the saved sequence number so that the - // VerifyXXX calls use the right sequence number for - // checking ACK numbers. - rep.NextSeqNum = savedSeqNum - if sackEnabled { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackDisabledConnect establishes a connection with the SACK option -// disabled and verifies that no SACKs are sent for out of order segments. -func TestSackDisabledConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) - - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older sequence number and - // no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackPermittedAccept accepts and establishes a connection with the -// SACKPermitted option enabled if the connection request specifies the -// SACKPermitted option. In case of SYN cookies SACK should be disabled as we -// don't encode the SACK information in the cookie. -func TestSackPermittedAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - sackPermitted bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number. - rep.NextSeqNum = savedSeqNum - if sackEnabled && tc.sackPermitted { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -// TestSackDisabledAccept accepts and establishes a connection with -// the SACKPermitted option disabled and verifies that no SACKs are -// sent for out of order packets. -func TestSackDisabledAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number and no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -func TestUpdateSACKBlocks(t *testing.T) { - testCases := []struct { - segStart seqnum.Value - segEnd seqnum.Value - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - updated []header.SACKBlock - }{ - // Trivial cases where current SACK block list is empty and we - // have an out of order delivery. - {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, - {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, - {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, - - // Cases where current SACK block list is not empty and we have - // an out of order delivery. Tests that the updated SACK block - // list has the first block as the one that contains the new - // SACK block representing the segment that was just delivered. - {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, - - // Ensure that we only retain header.MaxSACKBlocks and drop the - // oldest one if adding a new block exceeds - // header.MaxSACKBlocks. - {24, 30, 9, - []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, - []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, - - // Cases where segment extends an existing SACK block. - {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, - {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, - - // Cases where segment contains rcvNxt. - {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, - } - - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { - t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) - } - - } -} - -func TestTrimSackBlockList(t *testing.T) { - testCases := []struct { - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - trimmed []header.SACKBlock - }{ - // Simple cases where we trim whole entries. - {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, - {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, - {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, - {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - // Cases where we need to update a block. - {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, - {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, - {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, - {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - } - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.TrimSACKBlockList(&sack, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { - t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) - } - } -} - -func TestSACKRecovery(t *testing.T) { - const maxPayload = 10 - // See: tcp.makeOptions for why tsOptionSize is set to 12 here. - const tsOptionSize = 12 - // Enabling SACK means the payload size is reduced to account - // for the extra space required for the TCP options. - // - // We increase the MTU by 40 bytes to account for SACK and Timestamp - // options. - const maxTCPOptionSize = 40 - - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) - defer c.Cleanup() - - c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { - // We use log.Printf instead of t.Logf here because this probe - // can fire even when the test function has finished. This is - // because closing the endpoint in cleanup() does not mean the - // actual worker loop terminates immediately as it still has to - // do a full TCP shutdown. But this test can finish running - // before the shutdown is done. Using t.Logf in such a case - // causes the test to panic due to logging after test finished. - log.Printf("state: %+v\n", s) - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - - const iterations = 7 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end := start.Add(10) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) - } - } - - // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause - // window inflation and sending of packets is completely handled by the - // SACK Recovery algorithm. We should see no packets being released, as - // the cwnd at this point after entering recovery should be half of the - // outstanding number of packets in flight. - for i := 0; i < 7; i++ { - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. This along with the 10 sacked - // segments above should reduce the outstanding below the current - // congestion window allowing the sender to transmit data. - rtxOffset = bytesRead - expected*maxPayload/2 - - // Now send a partial ACK w/ a SACK block that indicates that the next 3 - // segments are lost and we have received 6 segments after the lost - // segments. This should cause the sender to immediately transmit all 3 - // segments in response to this ACK unlike in FastRecovery where only 1 - // segment is retransmitted per ACK. - start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end = start.Add(60) - c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) - - // At this point, we acked expected/2 packets and we SACKED 6 packets and - // 3 segments were considered lost due to the SACK block we sent. - // - // So total packets outstanding can be calculated as follows after 7 - // iterations of slow start -> 10/20/40/80/160/320/640. So expected - // should be 640 at start, then we went to recover at which point the - // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the - // network). - // Outstanding at this point after acking half the window - // (320 packets) will be: - // outstanding = 640-320-6(due to SACK block)-3 = 311 - // - // The last 3 is due to the fact that the first 3 packets after - // rtxOffset will be considered lost due to the SACK blocks sent. - // Receive the retransmit due to partial ack. - - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - // Receive the 2 extra packets that should have been retransmitted as - // those should be considered lost and immediately retransmitted based - // on the SACK information in the previous ACK sent above. - for i := 0; i < 2; i++ { - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) - } - - // Now we should get 9 more new unsent packets as the cwnd is 323 and - // outstanding is 311. - for i := 0; i < 9; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } - - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 9 - // packets outstanding. - // - // Now in the first iteration since there are 9 packets outstanding. - // We would expect to get expected/2 - 9 packets. But subsequent - // iterations will send us expected/2 + 1 (per iteration). - expected = expected/2 - 9 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 9 - } - expected++ - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go new file mode 100755 index 000000000..029f98a11 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,173 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *segmentList) Back() *segment { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *segmentList) PushFront(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(l.head) + segmentElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *segmentList) PushBack(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(nil) + segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *segmentList) InsertAfter(b, e *segment) { + a := segmentElementMapper{}.linkerFor(b).Next() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *segmentList) InsertBefore(a, e *segment) { + b := segmentElementMapper{}.linkerFor(a).Prev() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *segmentList) Remove(e *segment) { + prev := segmentElementMapper{}.linkerFor(e).Prev() + next := segmentElementMapper{}.linkerFor(e).Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go new file mode 100755 index 000000000..1674c6a08 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -0,0 +1,431 @@ +// automatically generated by stateify. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *cubicState) beforeSave() {} +func (x *cubicState) save(m state.Map) { + x.beforeSave() + var t unixTime = x.saveT() + m.SaveValue("t", t) + m.Save("wLastMax", &x.wLastMax) + m.Save("wMax", &x.wMax) + m.Save("numCongestionEvents", &x.numCongestionEvents) + m.Save("c", &x.c) + m.Save("k", &x.k) + m.Save("beta", &x.beta) + m.Save("wC", &x.wC) + m.Save("wEst", &x.wEst) + m.Save("s", &x.s) +} + +func (x *cubicState) afterLoad() {} +func (x *cubicState) load(m state.Map) { + m.Load("wLastMax", &x.wLastMax) + m.Load("wMax", &x.wMax) + m.Load("numCongestionEvents", &x.numCongestionEvents) + m.Load("c", &x.c) + m.Load("k", &x.k) + m.Load("beta", &x.beta) + m.Load("wC", &x.wC) + m.Load("wEst", &x.wEst) + m.Load("s", &x.s) + m.LoadValue("t", new(unixTime), func(y interface{}) { x.loadT(y.(unixTime)) }) +} + +func (x *SACKInfo) beforeSave() {} +func (x *SACKInfo) save(m state.Map) { + x.beforeSave() + m.Save("Blocks", &x.Blocks) + m.Save("NumBlocks", &x.NumBlocks) +} + +func (x *SACKInfo) afterLoad() {} +func (x *SACKInfo) load(m state.Map) { + m.Load("Blocks", &x.Blocks) + m.Load("NumBlocks", &x.NumBlocks) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var lastError string = x.saveLastError() + m.SaveValue("lastError", lastError) + var state EndpointState = x.saveState() + m.SaveValue("state", state) + var hardError string = x.saveHardError() + m.SaveValue("hardError", hardError) + var acceptedChan []*endpoint = x.saveAcceptedChan() + m.SaveValue("acceptedChan", acceptedChan) + m.Save("netProto", &x.netProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvList", &x.rcvList) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvBufUsed", &x.rcvBufUsed) + m.Save("id", &x.id) + m.Save("isRegistered", &x.isRegistered) + m.Save("v6only", &x.v6only) + m.Save("isConnectNotified", &x.isConnectNotified) + m.Save("broadcast", &x.broadcast) + m.Save("workerRunning", &x.workerRunning) + m.Save("workerCleanup", &x.workerCleanup) + m.Save("sendTSOk", &x.sendTSOk) + m.Save("recentTS", &x.recentTS) + m.Save("tsOffset", &x.tsOffset) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("sackPermitted", &x.sackPermitted) + m.Save("sack", &x.sack) + m.Save("reusePort", &x.reusePort) + m.Save("delay", &x.delay) + m.Save("cork", &x.cork) + m.Save("scoreboard", &x.scoreboard) + m.Save("reuseAddr", &x.reuseAddr) + m.Save("slowAck", &x.slowAck) + m.Save("segmentQueue", &x.segmentQueue) + m.Save("synRcvdCount", &x.synRcvdCount) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("sndBufUsed", &x.sndBufUsed) + m.Save("sndClosed", &x.sndClosed) + m.Save("sndBufInQueue", &x.sndBufInQueue) + m.Save("sndQueue", &x.sndQueue) + m.Save("cc", &x.cc) + m.Save("packetTooBigCount", &x.packetTooBigCount) + m.Save("sndMTU", &x.sndMTU) + m.Save("keepalive", &x.keepalive) + m.Save("rcv", &x.rcv) + m.Save("snd", &x.snd) + m.Save("bindAddress", &x.bindAddress) + m.Save("connectingAddress", &x.connectingAddress) + m.Save("gso", &x.gso) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.LoadWait("waiterQueue", &x.waiterQueue) + m.LoadWait("rcvList", &x.rcvList) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvBufUsed", &x.rcvBufUsed) + m.Load("id", &x.id) + m.Load("isRegistered", &x.isRegistered) + m.Load("v6only", &x.v6only) + m.Load("isConnectNotified", &x.isConnectNotified) + m.Load("broadcast", &x.broadcast) + m.Load("workerRunning", &x.workerRunning) + m.Load("workerCleanup", &x.workerCleanup) + m.Load("sendTSOk", &x.sendTSOk) + m.Load("recentTS", &x.recentTS) + m.Load("tsOffset", &x.tsOffset) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("sackPermitted", &x.sackPermitted) + m.Load("sack", &x.sack) + m.Load("reusePort", &x.reusePort) + m.Load("delay", &x.delay) + m.Load("cork", &x.cork) + m.Load("scoreboard", &x.scoreboard) + m.Load("reuseAddr", &x.reuseAddr) + m.Load("slowAck", &x.slowAck) + m.LoadWait("segmentQueue", &x.segmentQueue) + m.Load("synRcvdCount", &x.synRcvdCount) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("sndBufUsed", &x.sndBufUsed) + m.Load("sndClosed", &x.sndClosed) + m.Load("sndBufInQueue", &x.sndBufInQueue) + m.LoadWait("sndQueue", &x.sndQueue) + m.Load("cc", &x.cc) + m.Load("packetTooBigCount", &x.packetTooBigCount) + m.Load("sndMTU", &x.sndMTU) + m.Load("keepalive", &x.keepalive) + m.LoadWait("rcv", &x.rcv) + m.LoadWait("snd", &x.snd) + m.Load("bindAddress", &x.bindAddress) + m.Load("connectingAddress", &x.connectingAddress) + m.Load("gso", &x.gso) + m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) }) + m.LoadValue("state", new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) + m.LoadValue("hardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) }) + m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *keepalive) beforeSave() {} +func (x *keepalive) save(m state.Map) { + x.beforeSave() + m.Save("enabled", &x.enabled) + m.Save("idle", &x.idle) + m.Save("interval", &x.interval) + m.Save("count", &x.count) + m.Save("unacked", &x.unacked) +} + +func (x *keepalive) afterLoad() {} +func (x *keepalive) load(m state.Map) { + m.Load("enabled", &x.enabled) + m.Load("idle", &x.idle) + m.Load("interval", &x.interval) + m.Load("count", &x.count) + m.Load("unacked", &x.unacked) +} + +func (x *receiver) beforeSave() {} +func (x *receiver) save(m state.Map) { + x.beforeSave() + m.Save("ep", &x.ep) + m.Save("rcvNxt", &x.rcvNxt) + m.Save("rcvAcc", &x.rcvAcc) + m.Save("rcvWndScale", &x.rcvWndScale) + m.Save("closed", &x.closed) + m.Save("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Save("pendingBufUsed", &x.pendingBufUsed) + m.Save("pendingBufSize", &x.pendingBufSize) +} + +func (x *receiver) afterLoad() {} +func (x *receiver) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("rcvNxt", &x.rcvNxt) + m.Load("rcvAcc", &x.rcvAcc) + m.Load("rcvWndScale", &x.rcvWndScale) + m.Load("closed", &x.closed) + m.Load("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Load("pendingBufUsed", &x.pendingBufUsed) + m.Load("pendingBufSize", &x.pendingBufSize) +} + +func (x *renoState) beforeSave() {} +func (x *renoState) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *renoState) afterLoad() {} +func (x *renoState) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *SACKScoreboard) beforeSave() {} +func (x *SACKScoreboard) save(m state.Map) { + x.beforeSave() + m.Save("smss", &x.smss) + m.Save("maxSACKED", &x.maxSACKED) +} + +func (x *SACKScoreboard) afterLoad() {} +func (x *SACKScoreboard) load(m state.Map) { + m.Load("smss", &x.smss) + m.Load("maxSACKED", &x.maxSACKED) +} + +func (x *segment) beforeSave() {} +func (x *segment) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + var options []byte = x.saveOptions() + m.SaveValue("options", options) + var rcvdTime unixTime = x.saveRcvdTime() + m.SaveValue("rcvdTime", rcvdTime) + var xmitTime unixTime = x.saveXmitTime() + m.SaveValue("xmitTime", xmitTime) + m.Save("segmentEntry", &x.segmentEntry) + m.Save("refCnt", &x.refCnt) + m.Save("viewToDeliver", &x.viewToDeliver) + m.Save("sequenceNumber", &x.sequenceNumber) + m.Save("ackNumber", &x.ackNumber) + m.Save("flags", &x.flags) + m.Save("window", &x.window) + m.Save("csum", &x.csum) + m.Save("csumValid", &x.csumValid) + m.Save("parsedOptions", &x.parsedOptions) + m.Save("hasNewSACKInfo", &x.hasNewSACKInfo) +} + +func (x *segment) afterLoad() {} +func (x *segment) load(m state.Map) { + m.Load("segmentEntry", &x.segmentEntry) + m.Load("refCnt", &x.refCnt) + m.Load("viewToDeliver", &x.viewToDeliver) + m.Load("sequenceNumber", &x.sequenceNumber) + m.Load("ackNumber", &x.ackNumber) + m.Load("flags", &x.flags) + m.Load("window", &x.window) + m.Load("csum", &x.csum) + m.Load("csumValid", &x.csumValid) + m.Load("parsedOptions", &x.parsedOptions) + m.Load("hasNewSACKInfo", &x.hasNewSACKInfo) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) + m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) }) + m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) }) + m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) }) +} + +func (x *segmentQueue) beforeSave() {} +func (x *segmentQueue) save(m state.Map) { + x.beforeSave() + m.Save("list", &x.list) + m.Save("limit", &x.limit) + m.Save("used", &x.used) +} + +func (x *segmentQueue) afterLoad() {} +func (x *segmentQueue) load(m state.Map) { + m.LoadWait("list", &x.list) + m.Load("limit", &x.limit) + m.Load("used", &x.used) +} + +func (x *sender) beforeSave() {} +func (x *sender) save(m state.Map) { + x.beforeSave() + var lastSendTime unixTime = x.saveLastSendTime() + m.SaveValue("lastSendTime", lastSendTime) + var rttMeasureTime unixTime = x.saveRttMeasureTime() + m.SaveValue("rttMeasureTime", rttMeasureTime) + m.Save("ep", &x.ep) + m.Save("dupAckCount", &x.dupAckCount) + m.Save("fr", &x.fr) + m.Save("sndCwnd", &x.sndCwnd) + m.Save("sndSsthresh", &x.sndSsthresh) + m.Save("sndCAAckCount", &x.sndCAAckCount) + m.Save("outstanding", &x.outstanding) + m.Save("sndWnd", &x.sndWnd) + m.Save("sndUna", &x.sndUna) + m.Save("sndNxt", &x.sndNxt) + m.Save("sndNxtList", &x.sndNxtList) + m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Save("closed", &x.closed) + m.Save("writeNext", &x.writeNext) + m.Save("writeList", &x.writeList) + m.Save("rtt", &x.rtt) + m.Save("rto", &x.rto) + m.Save("srttInited", &x.srttInited) + m.Save("maxPayloadSize", &x.maxPayloadSize) + m.Save("gso", &x.gso) + m.Save("sndWndScale", &x.sndWndScale) + m.Save("maxSentAck", &x.maxSentAck) + m.Save("cc", &x.cc) +} + +func (x *sender) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("dupAckCount", &x.dupAckCount) + m.Load("fr", &x.fr) + m.Load("sndCwnd", &x.sndCwnd) + m.Load("sndSsthresh", &x.sndSsthresh) + m.Load("sndCAAckCount", &x.sndCAAckCount) + m.Load("outstanding", &x.outstanding) + m.Load("sndWnd", &x.sndWnd) + m.Load("sndUna", &x.sndUna) + m.Load("sndNxt", &x.sndNxt) + m.Load("sndNxtList", &x.sndNxtList) + m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Load("closed", &x.closed) + m.Load("writeNext", &x.writeNext) + m.Load("writeList", &x.writeList) + m.Load("rtt", &x.rtt) + m.Load("rto", &x.rto) + m.Load("srttInited", &x.srttInited) + m.Load("maxPayloadSize", &x.maxPayloadSize) + m.Load("gso", &x.gso) + m.Load("sndWndScale", &x.sndWndScale) + m.Load("maxSentAck", &x.maxSentAck) + m.Load("cc", &x.cc) + m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) }) + m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *rtt) beforeSave() {} +func (x *rtt) save(m state.Map) { + x.beforeSave() + m.Save("srtt", &x.srtt) + m.Save("rttvar", &x.rttvar) +} + +func (x *rtt) afterLoad() {} +func (x *rtt) load(m state.Map) { + m.Load("srtt", &x.srtt) + m.Load("rttvar", &x.rttvar) +} + +func (x *fastRecovery) beforeSave() {} +func (x *fastRecovery) save(m state.Map) { + x.beforeSave() + m.Save("active", &x.active) + m.Save("first", &x.first) + m.Save("last", &x.last) + m.Save("maxCwnd", &x.maxCwnd) + m.Save("highRxt", &x.highRxt) + m.Save("rescueRxt", &x.rescueRxt) +} + +func (x *fastRecovery) afterLoad() {} +func (x *fastRecovery) load(m state.Map) { + m.Load("active", &x.active) + m.Load("first", &x.first) + m.Load("last", &x.last) + m.Load("maxCwnd", &x.maxCwnd) + m.Load("highRxt", &x.highRxt) + m.Load("rescueRxt", &x.rescueRxt) +} + +func (x *unixTime) beforeSave() {} +func (x *unixTime) save(m state.Map) { + x.beforeSave() + m.Save("second", &x.second) + m.Save("nano", &x.nano) +} + +func (x *unixTime) afterLoad() {} +func (x *unixTime) load(m state.Map) { + m.Load("second", &x.second) + m.Load("nano", &x.nano) +} + +func (x *segmentList) beforeSave() {} +func (x *segmentList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *segmentList) afterLoad() {} +func (x *segmentList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *segmentEntry) beforeSave() {} +func (x *segmentEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *segmentEntry) afterLoad() {} +func (x *segmentEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("tcp.cubicState", (*cubicState)(nil), state.Fns{Save: (*cubicState).save, Load: (*cubicState).load}) + state.Register("tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load}) + state.Register("tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load}) + state.Register("tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load}) + state.Register("tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load}) + state.Register("tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load}) + state.Register("tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load}) + state.Register("tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load}) + state.Register("tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load}) + state.Register("tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load}) + state.Register("tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load}) + state.Register("tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load}) + state.Register("tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load}) + state.Register("tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go deleted file mode 100644 index ef7ec759d..000000000 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ /dev/null @@ -1,3978 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65535 - - // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an - // IPv4 endpoint when the MTU is set to defaultMTU in the test. - defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize -) - -func TestGiveUpConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - // Close the connection, wait for completion. - ep.Close() - - // Wait for ep to become writable. - <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) - } -} - -func TestConnectIncrementActiveConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - - c.CreateConnected(789, 30000, nil) - if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.FailedConnectionAttempts.Value() - - c.CreateConnected(789, 30000, nil) - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 - - if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) - } - - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) - } -} - -func TestTCPSegmentsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - // SYN and ACK - want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, nil) - - if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) - } -} - -func TestTCPResetsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - want := stats.TCP.SegmentsSent.Value() + 1 - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - // If the AckNum is not the increment of the last sequence number, a RST - // segment is sent back in response. - AckNum: c.IRS + 2, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.GetPacket() - if got := stats.TCP.ResetsSent.Value(); got != want { - t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) - } -} - -func TestTCPResetsReceivedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - ackNum := seqnum.Value(789) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(ackNum, rcvWnd, nil) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: c.IRS.Add(2), - AckNum: ackNum.Add(2), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) - } -} - -func TestActiveHandshake(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) -} - -func TestNonBlockingClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - ep := c.EP - c.EP = nil - - // Close the endpoint and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) - } -} - -func TestConnectResetAfterClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - ep := c.EP - c.EP = nil - - // Close the endpoint, make sure we get a FIN segment, then acknowledge - // to complete closure of sender, but don't send our own FIN. - ep.Close() - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) - for { - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { - // This is a retransmit of the FIN, ignore it. - continue - } - - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - ), - ) - break - } -} - -func TestSimpleReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestOutOfOrderReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send second half of data first, with seqnum 3 ahead of expected. - data := []byte{1, 2, 3, 4, 5, 6} - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that we get an ACK specifying which seqnum is expected. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait 200ms and check that no data has been received. - time.Sleep(200 * time.Millisecond) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send the first 3 bytes now. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive data. - read := make([]byte, 0, 6) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } - t.Fatalf("Read failed: %v", err) - } - - read = append(read, v...) - } - - // Check that we received the data in proper order. - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that the whole data is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestOutOfOrderFlood(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create a new connection with initial window size of 10. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send 100 packets before the actual one that is expected. - data := []byte{1, 2, 3, 4, 5, 6} - for i := 0; i < 100; i++ { - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 796, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Send packet with seqnum 793. It must be discarded because the - // out-of-order buffer was filled by the previous packets. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now send the expected packet, seqnum 790. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that only packet 790 is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestRstOnCloseWithUnreadData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now that we know we have unread data, let's just close the connection - // and verify that netstack sends an RST rather than a FIN. - c.EP.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // This final ACK should be ignored because an ACK on a reset doesn't mean - // anything. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) - - // Make sure we get the FIN but DON't ACK IT. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), - )) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Cause a RST to be generated by closing the read end now since we have - // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) - - // Make sure we get the RST - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // The ACK to the FIN should now be rejected since the connection has been - // closed by a RST. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + len(data)), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestShutdownRead(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) - } -} - -func TestFullWindowReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - _, _, err := c.EP.Read(nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) - } - - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and window goes to zero. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), - ), - ) - - // Receive data and check it. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Check that we get an ACK for the newly non-zero window. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), - ), - ) -} - -func TestNoWindowShrinking(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Start off with a window size of 10, then shrink it to 5. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) - - opt = 5 - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), - ), - ) - - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 793, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), - ), - ) - - // Receive data and check it. - read := make([]byte, 0, 10) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - read = append(read, v...) - } - - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that we get an ACK for the newly non-zero window, which is the - // new size. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), - ), - ) -} - -func TestSimpleSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestZeroWindowSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 0, nil) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Since the window is currently zero, check that no packet is received. - c.CheckNoPacket("Packet received when window is zero") - - // Open up the window. Data should be received now. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xbfff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnected(789, 30000, &opt) - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xbfff, - // that is, that it is scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestNonScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnect(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) -} - -func TestZeroScaledWindowReceive(t *testing.T) { - // This test ensures that the endpoint sends a non-zero window size - // advertisement when the scaled window transitions from 0 to non-zero, - // but the actual window (not scaled) hasn't gotten to zero. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - opt := tcpip.ReceiveBufferSizeOption(wnd) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - // Write chunks of 50000 bytes. - remain := wnd - sent := 0 - data := make([]byte, 50000) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Make the window non-zero, but the scaled window zero. - if remain >= 16 { - data = data[:remain-15] - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + sent), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Read some data. An ack should be sent in response to that. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(len(v)>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestSegmentMerging(t *testing.T) { - tests := []struct { - name string - stop func(tcpip.Endpoint) - resume func(tcpip.Endpoint) - }{ - { - "stop work", - func(ep tcpip.Endpoint) { - ep.(interface{ StopWork() }).StopWork() - }, - func(ep tcpip.Endpoint) { - ep.(interface{ ResumeWork() }).ResumeWork() - }, - }, - { - "cork", - func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) - }, - func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Prevent the endpoint from processing packets. - test.stop(c.EP) - - var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - // Let the endpoint process the segments that we just sent. - test.resume(c.EP) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), - RcvWnd: 30000, - }) - }) - } -} - -func TestDelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - c.EP.SetSockOpt(tcpip.DelayOption(1)) - - var allData []byte - for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - for _, want := range [][]byte{allData[:1], allData[1:]} { - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(want)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { - t.Fatalf("got data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(want))) - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - } -} - -func TestUndelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - c.EP.SetSockOpt(tcpip.DelayOption(1)) - - allData := [][]byte{{0}, {1, 2, 3}} - for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - - // Check that data is received. - first := c.GetPacket() - checker.IPv4(t, first, - checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { - t.Fatalf("got first packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[0]))) - - // Check that we don't get the second packet yet. - c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - - c.EP.SetSockOpt(tcpip.DelayOption(0)) - - // Check that data is received. - second := c.GetPacket() - checker.IPv4(t, second, - checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { - t.Fatalf("got second packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[1]))) - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) -} - -func TestMSSNotDelayed(t *testing.T) { - tests := []struct { - name string - fn func(tcpip.Endpoint) - }{ - {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const maxPayload = 100 - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - test.fn(c.EP) - - allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} - for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) - } - } - - seq := c.IRS.Add(1) - - for i, data := range allData { - // Check that data is received. - packet := c.GetPacket() - checker.IPv4(t, packet, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { - t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) - } - - seq = seq.Add(seqnum.Size(len(data))) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seq, - RcvWnd: 30000, - }) - }) - } -} - -func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { - payloadMultiplier := 10 - dataLen := payloadMultiplier * maxPayload - data := make([]byte, dataLen) - for i := range data { - data[i] = byte(i) - } - - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that data is received in chunks. - bytesReceived := 0 - numPackets := 0 - for bytesReceived != dataLen { - b := c.GetPacket() - numPackets++ - tcpHdr := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcpHdr.Payload()) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { - t.Fatalf("got data = %v, want = %v", p, pdata) - } - bytesReceived += payloadLen - var options []byte - if c.TimeStampEnabled { - // If timestamp option is enabled, echo back the timestamp and increment - // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcpHdr.ParsedOptions() - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) - options = tsOpt[:] - } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options, - }) - } - if numPackets == 1 { - t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") - } -} - -func TestSendGreaterThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestActiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, 65535) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestPassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 2 - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, wndScale, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 536 - const mtu = 2000 - c := context.New(t, mtu) - defer c.Cleanup() - - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestForwarderSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan *tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err *tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, 1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Wait for connection to be available. - select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynOptionsOnActiveConnect(t *testing.T) { - const mtu = 1400 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 2 - if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) - } - - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventOut) - defer c.WQ.EventUnregister(&we) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - // Receive SYN packet. - b := c.GetPacket() - mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Wait for retransmit. - time.Sleep(1 * time.Second) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - // Send SYN-ACK. - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestCloseListener(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Close the listener and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) - } -} - -func TestReceiveOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Try to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - -loop: - for { - switch _, _, err := c.EP.Read(nil); err { - case tcpip.ErrWouldBlock: - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") - } - case tcpip.ErrConnectionReset: - break loop - default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) - } - } -} - -func TestSendOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - RcvWnd: 30000, - }) - - // Wait for the RST to be received. - time.Sleep(1 * time.Second) - - // Try to write. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset) - } -} - -func TestFinImmediately(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Don't acknowledge yet. We should get a retransmit of the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithNoPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Shutdown, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingDataCwndFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Write enough segments to fill the congestion window before ACK'ing - // any of them. - view := buffer.NewView(10) - for i := tcp.InitialCwnd; i > 0; i-- { - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - } - - next := uint32(c.IRS) + 1 - for i := tcp.InitialCwnd; i > 0; i-- { - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - } - - // Shutdown the connection, check that the FIN segment isn't sent - // because the congestion window doesn't allow it. Wait until a - // retransmit is received. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Send the ACK that will allow the FIN to be sent as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPartialAck(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Write something out, and acknowledge it to get cwnd to 2. Also send - // FIN from the test side. - view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that we get an ACK for the fin. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send an ACK for the data, but not for the FIN yet. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 791, - AckNum: seqnum.Value(next - 1), - RcvWnd: 30000, - }) - - // Check that we don't get a retransmit of the FIN. - c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) - - // Ack the FIN. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 791, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) -} - -func TestUpdateListenBacklog(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Update the backlog with another Listen() on the same endpoint. - if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %v", err) - } - - ep.Close() -} - -func scaledSendWindow(t *testing.T, scale uint8) { - // This test ensures that the endpoint is using the right scaling by - // sending a buffer that is larger than the window size, and ensuring - // that the endpoint doesn't send more than allowed. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - header.TCPOptionWS, 3, scale, header.TCPOptionNOP, - }) - - // Open up the window with a scaled value. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 1, - }) - - // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Check that only data that fits in the scaled window is sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen((1<<scale)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Reset the connection to free resources. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: 790, - }) -} - -func TestScaledSendWindow(t *testing.T) { - for scale := uint8(0); scale <= 14; scale++ { - scaledSendWindow(t, scale) - } -} - -func TestReceivedValidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, nil) - stats := c.Stack().Stats() - want := stats.TCP.ValidSegmentsReceived.Value() + 1 - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) - } -} - -func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, nil) - stats := c.Stack().Stats() - want := stats.TCP.InvalidSegmentsReceived.Value() + 1 - vv := c.BuildSegment(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] - tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 - - c.SendSegment(vv) - - if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) - } -} - -func TestReceivedIncorrectChecksumIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(789, 30000, nil) - stats := c.Stack().Stats() - want := stats.TCP.ChecksumErrors.Value() + 1 - vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] - // Overwrite a byte in the payload which should cause checksum - // verification to fail. - tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 - - c.SendSegment(vv) - - if got := stats.TCP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) - } -} - -func TestReceivedSegmentQueuing(t *testing.T) { - // This test sends 200 segments containing a few bytes each to an - // endpoint and checks that they're all received and acknowledged by - // the endpoint, that is, that none of the segments are dropped by - // internal queues. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - // Send 200 segments. - data := []byte{1, 2, 3} - for i := 0; i < 200; i++ { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(790 + i*len(data)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - } - - // Receive ACKs for all segments. - last := seqnum.Value(790 + 200*len(data)) - for { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - tcpHdr := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcpHdr.AckNumber()) - if ack == last { - break - } - - if last.LessThan(ack) { - t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) - } - } -} - -func TestReadAfterClosedState(t *testing.T) { - // This test ensures that calling Read() or Peek() after the endpoint - // has transitioned to closedState still works if there is pending - // data. To transition to stateClosed without calling Close(), we must - // shutdown the send path and the peer must send its own FIN. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Shutdown immediately for write, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Send some data and acknowledge the FIN. - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: 790, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(uint32(791+len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that peek works. - peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) - if err != nil { - t.Fatalf("Peek failed: %v", err) - } - - peekBuf = peekBuf[:n] - if !bytes.Equal(data, peekBuf) { - t.Fatalf("got data = %v, want = %v", peekBuf, data) - } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Now that we drained the queue, check that functions fail with the - // right error code. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) - } - - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) - } -} - -func TestReusePort(t *testing.T) { - // This test ensures that ports are immediately available for reuse - // after Close on the endpoints using them returns. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // First case, just an endpoint that was bound. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - c.EP.Close() - - // Second case, an endpoint that was bound and is connecting.. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - c.EP.Close() - - // Third case, an endpoint that was bound and is listening. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } -} - -func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - var s tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - if int(s) != v { - t.Fatalf("got receive buffer size = %v, want = %v", s, v) - } -} - -func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - var s tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - if int(s) != v { - t.Fatalf("got send buffer size = %v, want = %v", s, v) - } -} - -func TestDefaultBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer func() { - if ep != nil { - ep.Close() - } - }() - - checkSendBufferSize(t, ep, tcp.DefaultBufferSize) - checkRecvBufferSize(t, ep, tcp.DefaultBufferSize) - - // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultBufferSize) - - // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*3) -} - -func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Set values below the min. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkRecvBufferSize(t, ep, 200) - - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkSendBufferSize(t, ep, 300) - - // Set values above the max. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultBufferSize*20)); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*20) - - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultBufferSize*30)); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30) -} - -func makeStack() (*stack.Stack, *tcpip.Error) { - s := stack.New([]string{ - ipv4.ProtocolName, - ipv6.ProtocolName, - }, []string{tcp.ProtocolName}, stack.Options{}) - - id := loopback.New() - if testing.Verbose() { - id = sniffer.New(id) - } - - if err := s.CreateNIC(1, id); err != nil { - return nil, err - } - - for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address - }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, - } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { - return nil, err - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }, - { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }, - }) - - return s, nil -} - -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) - defer wq.EventUnregister(&waitEntry) - - if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) - } - - <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Write something. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Read back what was written. - wq.EventUnregister(&waitEntry) - wq.EventRegister(&waitEntry, waiter.EventIn) - rd, _, err := ep.Read(nil) - if err != nil { - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) - } - <-notifyCh - rd, _, err = ep.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - } - - if !bytes.Equal(data, rd) { - t.Fatalf("got data = %v, want = %v", rd, data) - } -} - -func TestConnectAvoidsBoundPorts(t *testing.T) { - addressTypes := func(t *testing.T, network string) []string { - switch network { - case "ipv4": - return []string{"v4"} - case "ipv6": - return []string{"v6"} - case "dual": - return []string{"v6", "mapped"} - default: - t.Fatalf("unknown network: '%s'", network) - } - - panic("unreachable") - } - - address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { - switch addressType { - case "v4": - if isAny { - return "" - } - return context.StackAddr - case "v6": - if isAny { - return "" - } - return context.StackV6Addr - case "mapped": - if isAny { - return context.V4MappedWildcardAddr - } - return context.StackV4MappedAddr - default: - t.Fatalf("unknown address type: '%s'", addressType) - } - - panic("unreachable") - } - // This test ensures that Endpoint.Connect doesn't select already-bound ports. - networks := []string{"ipv4", "ipv6", "dual"} - for _, exhaustedNetwork := range networks { - t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { - for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { - t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { - for _, isAny := range []bool{false, true} { - t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { - for _, candidateNetwork := range networks { - t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { - for _, candidateAddressType := range addressTypes(t, candidateNetwork) { - t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - var eps []tcpip.Endpoint - defer func() { - for _, ep := range eps { - ep.Close() - } - }() - makeEP := func(network string) tcpip.Endpoint { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch network { - case "ipv4": - networkProtocolNumber = ipv4.ProtocolNumber - case "ipv6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatalf("unknown network: '%s'", network) - } - ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - eps = append(eps, ep) - switch network { - case "ipv4": - case "ipv6": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) - } - case "dual": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) - } - default: - t.Fatalf("unknown network: '%s'", network) - } - return ep - } - - var v4reserved, v6reserved bool - switch exhaustedAddressType { - case "v4", "mapped": - v4reserved = true - case "v6": - v6reserved = true - // Dual stack sockets bound to v6 any reserve on v4 as - // well. - if isAny { - switch exhaustedNetwork { - case "ipv6": - case "dual": - v4reserved = true - default: - t.Fatalf("unknown address type: '%s'", exhaustedNetwork) - } - } - default: - t.Fatalf("unknown address type: '%s'", exhaustedAddressType) - } - var collides bool - switch candidateAddressType { - case "v4", "mapped": - collides = v4reserved - case "v6": - collides = v6reserved - default: - t.Fatalf("unknown address type: '%s'", candidateAddressType) - } - - for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { - if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %v", i, err) - } - } - want := tcpip.ErrConnectStarted - if collides { - want = tcpip.ErrNoPortAvailable - } - if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want) - } - }) - } - }) - } - }) - } - }) - } - }) - } -} - -func TestPathMTUDiscovery(t *testing.T) { - // This test verifies the stack retransmits packets after it receives an - // ICMP packet indicating that the path MTU has been exceeded. - c := context.New(t, 1500) - defer c.Cleanup() - - // Create new connection with MSS of 1460. - const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - // Send 3200 bytes of data. - const writeSize = 3200 - data := buffer.NewView(writeSize) - for i := range data { - data[i] = byte(i) - } - - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { - var ret []byte - for i, size := range sizes { - p := c.GetPacket() - if i == which { - ret = p - } - checker.IPv4(t, p, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(seqNum), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - seqNum += uint32(size) - } - return ret - } - - // Receive three packets. - sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} - first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) - - // Send "packet too big" messages back to netstack. - const newMTU = 1200 - const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize - mtu := []byte{0, 0, newMTU / 256, newMTU % 256} - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) - - // See retransmitted packets. None exceeding the new max. - sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} - receivePackets(c, sizes, -1, uint32(c.IRS)+1) -} - -func TestTCPEndpointProbe(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - invoked := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint ID is what we expect. - // - // We don't do an extensive validation of every field but a - // basic sanity test. - if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("got LocalAddress: %q, want: %q", got, want) - } - if got, want := state.ID.LocalPort, c.Port; got != want { - t.Fatalf("got LocalPort: %d, want: %d", got, want) - } - if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("got RemoteAddress: %q, want: %q", got, want) - } - if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("got RemotePort: %d, want: %d", got, want) - } - - invoked <- struct{}{} - }) - - c.CreateConnected(789, 30000, nil) - - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-invoked: - case <-time.After(100 * time.Millisecond): - t.Fatalf("TCP Probe function was not called") - } -} - -func TestStackSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err *tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - var oldCC tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - - got, want := cc, oldCC - // If SetTransportProtocolOption is expected to succeed - // then the returned value for congestion control should - // match the one specified in the - // SetTransportProtocolOption call above, else it should - // be what it was before the call to - // SetTransportProtocolOption. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } -} - -func TestStackAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) - } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestStackSetAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) - } - - // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestEndpointSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err *tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, - } - - for _, connected := range []bool{false, true} { - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - var oldCC tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) - } - - if connected { - c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) - } - - if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) - } - - got, want := cc, oldCC - // If SetSockOpt is expected to succeed then the - // returned value for congestion control should match - // the one specified in the SetSockOpt above, else it - // should be what it was before the call to SetSockOpt. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } - } -} - -func enableCUBIC(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) - } -} - -func TestKeepalive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(789, 30000, nil) - - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) - - // 5 unacked keepalives are sent. ACK each one, and check that the - // connection stays alive after 5. - for i := 0; i < 10; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Acknowledge the keepalive. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: c.IRS, - RcvWnd: 30000, - }) - } - - // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) - } - - // Send some data and wait before ACKing it. Keepalives should be disabled - // during this period. - view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - // Wait for the packet to be retransmitted. Verify that no keepalives - // were sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - - next += uint32(len(view)) - - // Send ACK. Keepalives should start sending again. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: 790, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Now receive 5 keepalives, but don't ACK them. The connection - // should be reset after 5. - for i := 0; i < 5; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next-1)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // The connection should be terminated after 5 unacked keepalives. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - ), - ) - - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) - } -} - -func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - // Send a SYN request. - irs = seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply.w - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptions(), - })) - } - - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -// TestListenBacklogFull tests that netstack does not complete handshakes if the -// listen backlog for the endpoint is full. -func TestListenBacklogFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 2 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) - } - - time.Sleep(50 * time.Millisecond) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() - if err != tcpip.ErrWouldBlock { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } - - // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - - newEP, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestListenSynRcvdQueueFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send two SYN's the first one should get a SYN-ACK, the - // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - newEP, _, err := c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestListenBacklogFullSynCookieInUse(t *testing.T) { - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 1 - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 1 - portOffset := uint16(0) - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - executeHandshake(t, c, context.TestPort+portOffset, false) - portOffset++ - // Wait for this to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - // The Syn should be dropped as the endpoint's backlog is full. - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Verify that there is only one acceptable connection at this point. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() - if err != tcpip.ErrWouldBlock { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } -} - -func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - - srcPort := uint16(context.TestPort) - executeHandshake(t, c, srcPort+1, false) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) - } -} - -func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - c.EP = ep - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - srcPort := uint16(context.TestPort) - // Now attempt a handshakes it will fill up the accept backlog. - executeHandshake(t, c, srcPort, false) - - // Give time for the final ACK to be processed as otherwise the next handshake could - // get accepted before the previous one based on goroutine scheduling. - time.Sleep(50 * time.Millisecond) - - want := stats.TCP.ListenOverflowSynDrop.Value() + 1 - - // Now we will send one more SYN and this one should get dropped - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), - RcvWnd: 30000, - }) - - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestEndpointBindListenAcceptState(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - aep, _, err := ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - aep, _, err = ep.Accept() - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - // Listening endpoint remains in listen state. - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - ep.Close() - // Give worker goroutines time to receive the close notification. - time.Sleep(1 * time.Second) - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - // Accepted endpoint remains open when the listen endpoint is closed. - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - -} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go deleted file mode 100644 index ad300b90b..000000000 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -// createConnectedWithTimestampOption creates and connects c.ep with the -// timestamp option enabled. -func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) -} - -// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on -// an active connect and sets the TS Echo Reply fields correctly when the -// SYN-ACK also indicates support for the TS option and provides a TSVal. -func TestTimeStampEnabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read and validate that we have data to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - // The following tests ensure that TS option once enabled behaves - // correctly as described in - // https://tools.ietf.org/html/rfc7323#section-4.3. - // - // We are not testing delayed ACKs here, but we do test out of order - // packet delivery and filling the sequence number hole created due to - // the out of order packet. - // - // The test also verifies that the sequence numbers and timestamps are - // as expected. - data := []byte{1, 2, 3} - - // First we increment tsVal by a small amount. - tsVal := rep.TSVal + 100 - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Next we send an out of order packet. - rep.NextSeqNum += 3 - tsVal += 200 - rep.SendPacketWithTS(data, tsVal) - - // The ACK should contain the original sequenceNumber and an older TS. - rep.NextSeqNum -= 6 - rep.VerifyACKWithTS(tsVal - 200) - - // Next we fill the hole and the returned ACK should contain the - // cumulative sequence number acking all data sent till now and have the - // latest timestamp sent below in its TSEcr field. - tsVal -= 100 - rep.SendPacketWithTS(data, tsVal) - rep.NextSeqNum += 3 - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal by a large value that doesn't result in a wrap around. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal again by a large value which should cause the - // timestamp value to wrap around. The returned ACK should contain the - // wrapped around timestamp in its tsEcr field and not the tsVal from - // the previous packet sent above. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // There should be 5 views to read and each of them should - // contain the same data. - for i := 0; i < 5; i++ { - got, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if want := data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } - } -} - -// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an -// active connect but if the SYN-ACK doesn't specify the TS option then -// timestamp option is not enabled and future packets do not contain a -// timestamp. -func TestTimeStampDisabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithOptions(header.TCPSynOptions{}) -} - -func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - c := context.New(t, defaultMTU) - defer c.Cleanup() - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) - - // Now send some data and validate that timestamp is echoed correctly in the ACK. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) - } - - // Check that data is received and that the timestamp option TSEcr field - // matches the expected value. - b := c.GetPacket() - checker.IPv4(t, b, - // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 - // byte boundary. - checker.PayloadLen(len(data)+header.TCPMinimumSize+12), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(true, 0, tsVal+1), - ), - ) -} - -// TestTimeStampEnabledAccept tests that if the SYN on a passive connect -// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK -// and echoes the tsVal field of the original SYN in the tcEcr field of the -// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify -// that Timestamp option is enabled in both cases if requested in the original -// SYN. -func TestTimeStampEnabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - for _, tc := range testCases { - timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - - c := context.New(t, defaultMTU) - defer c.Cleanup() - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now send some data with the accepted connection endpoint and validate - // that no timestamp option is sent in the TCP segment. - data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) - } - - // Check that data is received and that the timestamp option is disabled - // when SYN cookies are enabled/disabled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - checker.TCPTimestampChecker(false, 0, 0), - ), - ) -} - -// TestTimeStampDisabledAccept tests that Timestamp option is not used when the -// peer doesn't advertise it and connection is established with Accept(). -func TestTimeStampDisabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - for _, tc := range testCases { - timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func TestSendGreaterThanMTUWithOptions(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - createConnectedWithTimestampOption(c) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.EventIn) - defer c.WQ.EventUnregister(&we) - - droppedPacketsStat := c.Stack().Stats().DroppedPackets - droppedPackets := droppedPacketsStat.Value() - data := []byte{1, 2, 3} - // Send a packet with no TCP options/timestamp. - rep.SendPacket(data, nil) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Assert that DroppedPackets was not incremented. - if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { - t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) - } - - // Issue a read and we should data. - got, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if want := data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD deleted file mode 100644 index 19b0d31c5..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - testonly = 1, - srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", - visibility = [ - "//:sandbox", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go deleted file mode 100644 index 2cfeed224..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ /dev/null @@ -1,1032 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provides a test context for use in tcp tests. It also -// provides helper methods to assert/check certain behaviours. -package context - -import ( - "bytes" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // StackAddr is the IPv4 address assigned to the stack. - StackAddr = "\x0a\x00\x00\x01" - - // StackPort is used as the listening port in tests for passive - // connects. - StackPort = 1234 - - // TestAddr is the source address for packets sent to the stack via the - // link layer endpoint. - TestAddr = "\x0a\x00\x00\x02" - - // TestPort is the TCP port used for packets sent to the stack - // via the link layer endpoint. - TestPort = 4096 - - // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - - // TestV6Addr is the source address for packets sent to the stack via - // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - // StackV4MappedAddr is StackAddr as a mapped v6 address. - StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr - - // TestV4MappedAddr is TestAddr as a mapped v6 address. - TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr - - // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - // testInitialSequenceNumber is the initial sequence number sent in packets that - // are sent in response to a SYN or in the initial SYN sent to the stack. - testInitialSequenceNumber = 789 -) - -// defaultWindowScale value specified here depends on the tcp.DefaultBufferSize -// constant defined in the tcp/endpoint.go because the tcp.DefaultBufferSize is -// used in tcp.newHandshake to determine the window scale to use when sending a -// SYN/SYN-ACK. -var defaultWindowScale = tcp.FindWndScale(tcp.DefaultBufferSize) - -// Headers is used to represent the TCP header fields when building a -// new packet. -type Headers struct { - // SrcPort holds the src port value to be used in the packet. - SrcPort uint16 - - // DstPort holds the destination port value to be used in the packet. - DstPort uint16 - - // SeqNum is the value of the sequence number field in the TCP header. - SeqNum seqnum.Value - - // AckNum represents the acknowledgement number field in the TCP header. - AckNum seqnum.Value - - // Flags are the TCP flags in the TCP header. - Flags int - - // RcvWnd is the window to be advertised in the ReceiveWindow field of - // the TCP header. - RcvWnd seqnum.Size - - // TCPOpts holds the options to be sent in the option field of the TCP - // header. - TCPOpts []byte -} - -// Context provides an initialized Network stack and a link layer endpoint -// for use in TCP tests. -type Context struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - // IRS holds the initial sequence number in the SYN sent by endpoint in - // case of an active connect or the sequence number sent by the endpoint - // in the SYN-ACK sent in response to a SYN when listening in passive - // mode. - IRS seqnum.Value - - // Port holds the port bound by EP below in case of an active connect or - // the listening port number in case of a passive connect. - Port uint16 - - // EP is the test endpoint in the stack owned by this context. This endpoint - // is used in various tests to either initiate an active connect or is used - // as a passive listening endpoint to accept inbound connections. - EP tcpip.Endpoint - - // Wq is the wait queue associated with EP and is used to block for events - // on EP. - WQ waiter.Queue - - // TimeStampEnabled is true if ep is connected with the timestamp option - // enabled. - TimeStampEnabled bool -} - -// New allocates and initializes a test context containing a new -// stack and a link-layer endpoint. -func New(t *testing.T, mtu uint32) *Context { - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) - - // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) - } - - // Some of the congestion control tests send up to 640 packets, we so - // set the channel size to 1000. - id, linkEP := channel.New(1000, mtu, "") - if testing.Verbose() { - id = sniffer.New(id) - } - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }, - { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }, - }) - - return &Context{ - t: t, - s: s, - linkEP: linkEP, - } -} - -// Cleanup closes the context endpoint if required. -func (c *Context) Cleanup() { - if c.EP != nil { - c.EP.Close() - } -} - -// Stack returns a reference to the stack in the Context. -func (c *Context) Stack() *stack.Stack { - return c.s -} - -// CheckNoPacketTimeout verifies that no packet is received during the time -// specified by wait. -func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { - c.t.Helper() - - select { - case <-c.linkEP.C: - c.t.Fatal(errMsg) - - case <-time.After(wait): - } -} - -// CheckNoPacket verifies that no packet is received for 1 second. -func (c *Context) CheckNoPacket(errMsg string) { - c.CheckNoPacketTimeout(errMsg, 1*time.Second) -} - -// GetPacket reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -// GetPacketNonBlocking reads a packet from the link layer endpoint -// and verifies that it is an IPv4 packet with the expected source -// and destination address. If no packet is available it will return -// nil immediately. -func (c *Context) GetPacketNonBlocking() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - default: - return nil - } -} - -// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { - // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2)) - if len(buf) > maxTotalSize { - buf = buf[:maxTotalSize] - } - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) - icmp.SetType(typ) - icmp.SetCode(code) - - copy(icmp[header.ICMPv4MinimumSize:], p1) - copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) - - // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) -} - -// BuildSegment builds a TCP segment based on the given Headers and payload. -func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv4MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - return buf.ToVectorisedView() -} - -// SendSegment sends a TCP segment that has already been built and written to a -// buffer.VectorisedView. -func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) -} - -// SendPacket builds and sends a TCP segment(with the provided payload & TCP -// headers) in an IPv4 packet via the link layer endpoint. -func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) -} - -// SendAck sends an ACK packet. -func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { - c.SendAckWithSACK(seq, bytesReceived, nil) -} - -// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified. -func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) { - options := make([]byte, 40) - offset := 0 - if len(sackBlocks) > 0 { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) - } - - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options[:offset], - }) -} - -// ReceiveAndCheckPacket reads a packet from the link layer endpoint and -// verifies that the packet packet payload of packet matches the slice -// of data indicated by offset & size. -func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { - c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) -} - -// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size and skips optlen bytes in addition to the IP -// TCP headers when comparing the data. -func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize+optlen), - checker.TCP( - checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } -} - -// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size. It returns true if a packet was received and -// processed. -func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { - b := c.GetPacketNonBlocking() - if b == nil { - return false - } - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } - return true -} - -// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only -// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 -// only endpoint instead of a default dual stack socket. -func (c *Context) CreateV6Endpoint(v6only bool) { - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.EP.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } -} - -// GetV6Packet reads a single packet from the link layer endpoint of the context -// and asserts that it is an IPv6 Packet with the expected src/dest addresses. -func (c *Context) GetV6Packet() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of -// the context. -func (c *Context) SendV6Packet(payload []byte, h *Headers) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: TestV6Addr, - DstAddr: StackV6Addr, - }) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv6MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) -} - -// CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { - c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) -} - -// Connect performs the 3-way handshake for c.EP with the provided Initial -// Sequence Number (iss) and receive window(rcvWnd) and any options if -// specified. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -// -// PreCondition: c.EP must already be created. -func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { - c.t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - c.SendPacket(nil, &Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: options, - }) - - // Receive ACK packet. - checker.IPv4(c.t, c.GetPacket(), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.Port = tcpHdr.SourcePort() -} - -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - } - c.Connect(iss, rcvWnd, options) -} - -// RawEndpoint is just a small wrapper around a TCP endpoint's state to make -// sending data and ACK packets easy while being able to manipulate the sequence -// numbers and timestamp values as needed. -type RawEndpoint struct { - C *Context - SrcPort uint16 - DstPort uint16 - Flags int - NextSeqNum seqnum.Value - AckNum seqnum.Value - WndSize seqnum.Size - RecentTS uint32 // Stores the latest timestamp to echo back. - TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. - - // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. - SACKPermitted bool -} - -// SendPacketWithTS embeds the provided tsVal in the Timestamp option -// for the packet to be sent out. -func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { - r.TSVal = tsVal - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) - r.SendPacket(payload, tsOpt[:]) -} - -// SendPacket is a small wrapper function to build and send packets. -func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { - packetHeaders := &Headers{ - SrcPort: r.SrcPort, - DstPort: r.DstPort, - Flags: r.Flags, - SeqNum: r.NextSeqNum, - AckNum: r.AckNum, - RcvWnd: r.WndSize, - TCPOpts: opts, - } - r.C.SendPacket(payload, packetHeaders) - r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) -} - -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { - // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.TCPTimestampChecker(true, 0, tsVal), - ), - ) - // Store the parsed TSVal from the ack as recentTS. - tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) - opts := tcpSeg.ParsedOptions() - r.RecentTS = opts.TSVal -} - -// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. -func (r *RawEndpoint) VerifyACKNoSACK() { - r.VerifyACKHasSACK(nil) -} - -// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. -func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { - // Read ACK and verify that the TCP options in the segment do - // not contain a SACK block. - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.TCPSACKBlockChecker(sackBlocks), - ), - ) -} - -// CreateConnectedWithOptions creates and connects c.ep with the specified TCP -// options enabled and returns a RawEndpoint which represents the other end of -// the connection. -// -// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK -// does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) - defer c.WQ.EventUnregister(&waitEntry) - - testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} - err = c.EP.Connect(testFullAddr) - if err != tcpip.ErrConnectStarted { - c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) - } - // Receive SYN packet. - b := c.GetPacket() - // Validate that the syn has the timestamp option and a valid - // TS value. - mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: defaultWindowScale, - SACKPermitted: c.SACKEnabled(), - }), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpSeg := header.TCP(header.IPv4(b).Payload()) - synOptions := header.ParseSynOptions(tcpSeg.Options(), false) - - // Build options w/ tsVal to be sent in the SYN-ACK. - synAckOptions := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - if wantOptions.TS { - offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) - } - if wantOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) - } - - offset += header.AddTCPOptionPadding(synAckOptions, offset) - - // Build SYN-ACK. - c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(testInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - TCPOpts: synAckOptions[:offset], - }) - - // Read ACK. - ackPacket := c.GetPacket() - - // Verify TCP header fields. - tcpCheckers := []checker.TransportChecker{ - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), - } - - // Verify that tsEcr of ACK packet is wantOptions.TSVal if the - // timestamp option was enabled, if not then we verify that - // there is no timestamp in the ACK packet. - if wantOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) - - ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) - ackOptions := ackSeg.ParsedOptions() - - // Wait for connection to be established. - select { - case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Store the source port in use by the endpoint. - c.Port = tcpSeg.SourcePort() - - // Mark in context that timestamp option is enabled for this endpoint. - c.TimeStampEnabled = true - - return &RawEndpoint{ - C: c, - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagAck | header.TCPFlagPsh, - NextSeqNum: iss + 1, - AckNum: c.IRS.Add(1), - WndSize: 30000, - RecentTS: ackOptions.TSVal, - TSVal: wantOptions.TSVal, - SACKPermitted: wantOptions.SACKPermitted, - } -} - -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. -// -// The function returns a RawEndpoint representing the other end of the accepted -// endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if err := ep.Listen(10); err != nil { - c.t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept() - if err == tcpip.ErrWouldBlock { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept() - if err != nil { - c.t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - return rep -} - -// PassiveConnect just disables WindowScaling and delegates the call to -// PassiveConnectWithOptions. -func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { - synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) -} - -// PassiveConnectWithOptions initiates a new connection (with the specified TCP -// options enabled) to the port on which the Context.ep is listening for new -// connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. -// -// NOTE: MSS is not a negotiated option and it can be asymmetric -// in each direction. This function uses the maxPayload to set the MSS to be -// sent to the peer on a connect and validates that the MSS in the SYN-ACK -// response is equal to the MTU - (tcphdr len + iphdr len). -// -// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the -// value of the window scaling option to be sent in the SYN. If synOptions.WS > -// 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - opts := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - offset += header.EncodeMSSOption(uint32(maxPayload), opts) - - if synOptions.WS >= 0 { - offset += header.EncodeWSOption(3, opts[offset:]) - } - if synOptions.TS { - offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) - } - - if synOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(opts[offset:]) - } - - paddingToAdd := 4 - offset%4 - // Now add any padding bytes that might be required to quad align the - // options. - for i := offset; i < offset+paddingToAdd; i++ { - opts[i] = header.TCPOptionNOP - } - offset += paddingToAdd - - // Send a SYN request. - iss := seqnum.Value(testInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - TCPOpts: opts[:offset], - }) - - // Receive the SYN-ACK reply. Make sure MSS and other expected options - // are present. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(StackPort), - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), - checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), - } - - // If TS option was enabled in the original SYN then add a checker to - // validate the Timestamp option in the SYN-ACK. - if synOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) - rcvWnd := seqnum.Size(30000) - ackHeaders := &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: rcvWnd, - } - - // If WS was expected to be in effect then scale the advertised window - // correspondingly. - if synOptions.WS > 0 { - ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) - } - - parsedOpts := tcp.ParsedOptions() - if synOptions.TS { - // Echo the tsVal back to the peer in the tsEcr field of the - // timestamp option. - // Increment TSVal by 1 from the value sent in the SYN and echo - // the TSVal in the SYN-ACK in the TSEcr field. - opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) - ackHeaders.TCPOpts = opts[:] - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.Port = StackPort - - return &RawEndpoint{ - C: c, - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagPsh | header.TCPFlagAck, - NextSeqNum: iss + 1, - AckNum: c.IRS + 1, - WndSize: rcvWnd, - SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), - RecentTS: parsedOpts.TSVal, - TSVal: synOptions.TSVal + 1, - } -} - -// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true -// for the Stack in the context. -func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { - // Stack doesn't support SACK. So just return. - return false - } - return bool(v) -} - -// SetGSOEnabled enables or disables generic segmentation offload. -func (c *Context) SetGSOEnabled(enable bool) { - c.linkEP.GSO = enable -} - -// MSSWithoutOptions returns the value for the MSS used by the stack when no -// options are in use. -func (c *Context) MSSWithoutOptions() uint16 { - return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD deleted file mode 100644 index 4bec48c0f..000000000 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tcpconntrack", - srcs = ["tcp_conntrack.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack", - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcpconntrack_test", - size = "small", - srcs = ["tcp_conntrack_test.go"], - deps = [ - ":tcpconntrack", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go deleted file mode 100644 index 93712cd45..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tcpconntrack implements a TCP connection tracking object. It allows -// users with access to a segment stream to figure out when a connection is -// established, reset, and closed (and in the last case, who closed first). -package tcpconntrack - -import ( - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// Result is returned when the state of a TCB is updated in response to an -// inbound or outbound segment. -type Result int - -const ( - // ResultDrop indicates that the segment should be dropped. - ResultDrop Result = iota - - // ResultConnecting indicates that the connection remains in a - // connecting state. - ResultConnecting - - // ResultAlive indicates that the connection remains alive (connected). - ResultAlive - - // ResultReset indicates that the connection was reset. - ResultReset - - // ResultClosedByPeer indicates that the connection was gracefully - // closed, and the inbound stream was closed first. - ResultClosedByPeer - - // ResultClosedBySelf indicates that the connection was gracefully - // closed, and the outbound stream was closed first. - ResultClosedBySelf -) - -// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP -// connection and inform the caller when the connection has been closed. -type TCB struct { - inbound stream - outbound stream - - // State handlers. - handlerInbound func(*TCB, header.TCP) Result - handlerOutbound func(*TCB, header.TCP) Result - - // firstFin holds a pointer to the first stream to send a FIN. - firstFin *stream - - // state is the current state of the stream. - state Result -} - -// Init initializes the state of the TCB according to the initial SYN. -func (t *TCB) Init(initialSyn header.TCP) Result { - t.handlerInbound = synSentStateInbound - t.handlerOutbound = synSentStateOutbound - - iss := seqnum.Value(initialSyn.SequenceNumber()) - t.outbound.una = iss - t.outbound.nxt = iss.Add(logicalLen(initialSyn)) - t.outbound.end = t.outbound.nxt - - // Even though "end" is a sequence number, we don't know the initial - // receive sequence number yet, so we store the window size until we get - // a SYN from the peer. - t.inbound.una = 0 - t.inbound.nxt = 0 - t.inbound.end = seqnum.Value(initialSyn.WindowSize()) - t.state = ResultConnecting - return t.state -} - -// UpdateStateInbound updates the state of the TCB based on the supplied inbound -// segment. -func (t *TCB) UpdateStateInbound(tcp header.TCP) Result { - st := t.handlerInbound(t, tcp) - if st != ResultDrop { - t.state = st - } - return st -} - -// UpdateStateOutbound updates the state of the TCB based on the supplied -// outbound segment. -func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { - st := t.handlerOutbound(t, tcp) - if st != ResultDrop { - t.state = st - } - return st -} - -// IsAlive returns true as long as the connection is established(Alive) -// or connecting state. -func (t *TCB) IsAlive() bool { - return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed()) -} - -// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream. -func (t *TCB) OutboundSendSequenceNumber() seqnum.Value { - return t.outbound.nxt -} - -// InboundSendSequenceNumber returns the snd.NXT for the inbound stream. -func (t *TCB) InboundSendSequenceNumber() seqnum.Value { - return t.inbound.nxt -} - -// adapResult modifies the supplied "Result" according to the state of the TCB; -// if r is anything other than "Alive", or if one of the streams isn't closed -// yet, it is returned unmodified. Otherwise it's converted to either -// ClosedBySelf or ClosedByPeer depending on which stream was closed first. -func (t *TCB) adaptResult(r Result) Result { - // Check the unmodified case. - if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() { - return r - } - - // Find out which was closed first. - if t.firstFin == &t.outbound { - return ResultClosedBySelf - } - - return ResultClosedByPeer -} - -// synSentStateInbound is the state handler for inbound segments when the -// connection is in SYN-SENT state. -func synSentStateInbound(t *TCB, tcp header.TCP) Result { - flags := tcp.Flags() - ackPresent := flags&header.TCPFlagAck != 0 - ack := seqnum.Value(tcp.AckNumber()) - - // Ignore segment if ack is present but not acceptable. - if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) { - return ResultConnecting - } - - // If reset is specified, we will let the packet through no matter what - // but we will also destroy the connection if the ACK is present (and - // implicitly acceptable). - if flags&header.TCPFlagRst != 0 { - if ackPresent { - t.inbound.rstSeen = true - return ResultReset - } - return ResultConnecting - } - - // Ignore segment if SYN is not set. - if flags&header.TCPFlagSyn == 0 { - return ResultConnecting - } - - // Update state informed by this SYN. - irs := seqnum.Value(tcp.SequenceNumber()) - t.inbound.una = irs - t.inbound.nxt = irs.Add(logicalLen(tcp)) - t.inbound.end += irs - - t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize())) - - // If the ACK was set (it is acceptable), update our unacknowledgement - // tracking. - if ackPresent { - // Advance the "una" and "end" indices of the outbound stream. - if t.outbound.una.LessThan(ack) { - t.outbound.una = ack - } - - if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) { - t.outbound.end = end - } - } - - // Update handlers so that new calls will be handled by new state. - t.handlerInbound = allOtherInbound - t.handlerOutbound = allOtherOutbound - - return ResultAlive -} - -// synSentStateOutbound is the state handler for outbound segments when the -// connection is in SYN-SENT state. -func synSentStateOutbound(t *TCB, tcp header.TCP) Result { - // Drop outbound segments that aren't retransmits of the original one. - if tcp.Flags() != header.TCPFlagSyn || - tcp.SequenceNumber() != uint32(t.outbound.una) { - return ResultDrop - } - - // Update the receive window. We only remember the largest value seen. - if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end { - t.inbound.end = wnd - } - - return ResultConnecting -} - -// update updates the state of inbound and outbound streams, given the supplied -// inbound segment. For outbound segments, this same function can be called with -// swapped inbound/outbound streams. -func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result { - // Ignore segments out of the window. - s := seqnum.Value(tcp.SequenceNumber()) - if !inbound.acceptable(s, dataLen(tcp)) { - return ResultAlive - } - - flags := tcp.Flags() - if flags&header.TCPFlagRst != 0 { - inbound.rstSeen = true - return ResultReset - } - - // Ignore segments that don't have the ACK flag, and those with the SYN - // flag. - if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 { - return ResultAlive - } - - // Ignore segments that acknowledge not yet sent data. - ack := seqnum.Value(tcp.AckNumber()) - if outbound.nxt.LessThan(ack) { - return ResultAlive - } - - // Advance the "una" and "end" indices of the outbound stream. - if outbound.una.LessThan(ack) { - outbound.una = ack - } - - if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) { - outbound.end = end - } - - // Advance the "nxt" index of the inbound stream. - end := s.Add(logicalLen(tcp)) - if inbound.nxt.LessThan(end) { - inbound.nxt = end - } - - // Note the index of the FIN segment. And stash away a pointer to the - // first stream to see a FIN. - if flags&header.TCPFlagFin != 0 && !inbound.finSeen { - inbound.finSeen = true - inbound.fin = end - 1 - - if *firstFin == nil { - *firstFin = inbound - } - } - - return ResultAlive -} - -// allOtherInbound is the state handler for inbound segments in all states -// except SYN-SENT. -func allOtherInbound(t *TCB, tcp header.TCP) Result { - return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin)) -} - -// allOtherOutbound is the state handler for outbound segments in all states -// except SYN-SENT. -func allOtherOutbound(t *TCB, tcp header.TCP) Result { - return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin)) -} - -// streams holds the state of a TCP unidirectional stream. -type stream struct { - // The interval [una, end) is the allowed interval as defined by the - // receiver, i.e., anything less than una has already been acknowledged - // and anything greater than or equal to end is beyond the receiver - // window. The interval [una, nxt) is the acknowledgable range, whose - // right edge indicates the sequence number of the next byte to be sent - // by the sender, i.e., anything greater than or equal to nxt hasn't - // been sent yet. - una seqnum.Value - nxt seqnum.Value - end seqnum.Value - - // finSeen indicates if a FIN has already been sent on this stream. - finSeen bool - - // fin is the sequence number of the FIN. It is only valid after finSeen - // is set to true. - fin seqnum.Value - - // rstSeen indicates if a RST has already been sent on this stream. - rstSeen bool -} - -// acceptable determines if the segment with the given sequence number and data -// length is acceptable, i.e., if it's within the [una, end) window or, in case -// the window is zero, if it's a packet with no payload and sequence number -// equal to una. -func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - wnd := s.una.Size(s.end) - if wnd == 0 { - return segLen == 0 && segSeq == s.una - } - - // Make sure [segSeq, seqSeq+segLen) is non-empty. - if segLen == 0 { - segLen = 1 - } - - return seqnum.Overlap(s.una, wnd, segSeq, segLen) -} - -// closed determines if the stream has already been closed. This happens when -// a FIN has been set by the sender and acknowledged by the receiver. -func (s *stream) closed() bool { - return s.finSeen && s.fin.LessThan(s.una) -} - -// dataLen returns the length of the TCP segment payload. -func dataLen(tcp header.TCP) seqnum.Size { - return seqnum.Size(len(tcp) - int(tcp.DataOffset())) -} - -// logicalLen calculates the logical length of the TCP segment. -func logicalLen(tcp header.TCP) seqnum.Size { - l := dataLen(tcp) - flags := tcp.Flags() - if flags&header.TCPFlagSyn != 0 { - l++ - } - if flags&header.TCPFlagFin != 0 { - l++ - } - return l -} diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go deleted file mode 100644 index 5e271b7ca..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpconntrack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" -) - -// connected creates a connection tracker TCB and sets it to a connected state -// by performing a 3-way handshake. -func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: iss, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: irw, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: irs, - AckNum: iss + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: isw, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: iss + 1, - AckNum: irs + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: irw, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - return &tcb -} - -func TestConnectionRefused(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive RST. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionRefusedInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionResetInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestRetransmitOnSynSent(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Retransmit the same SYN. - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) - } -} - -func TestRetransmitOnSynRcvd(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. This will cause the state to go to SYN-RCVD. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Retransmit the original SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Transmit a SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} - -func TestClosedBySelf(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Send FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1236, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestClosedByPeer(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Receive FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 791, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) - } -} - -func TestSendAndReceiveDataClosedBySelf(t *testing.T) { - sseq := uint32(1234) - rseq := uint32(789) - tcb := connected(t, sseq, rseq, 30000, 50000) - sseq++ - rseq++ - - // Send some data. - tcp := make(header.TCP, header.TCPMinimumSize+1024) - - for i := uint32(0); i < 10; i++ { - // Send some data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - sseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - for i := uint32(0); i < 10; i++ { - // Receive some data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - rseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - // Send FIN. - tcp = tcp[:header.TCPMinimumSize] - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - sseq++ - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - rseq++ - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestIgnoreBadResetOnSynSent(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive a RST with a bad ACK, it should not cause the connection to - // be reset. - acks := []uint32{1234, 1236, 1000, 5000} - flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} - for _, a := range acks { - for _, f := range flags { - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: a, - DataOffset: header.TCPMinimumSize, - Flags: f, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - } - - // Complete the handshake. - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD deleted file mode 100644 index 6dac66b50..000000000 --- a/pkg/tcpip/transport/udp/BUILD +++ /dev/null @@ -1,66 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "udp_packet_list", - out = "udp_packet_list.go", - package = "udp", - prefix = "udpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*udpPacket", - "Linker": "*udpPacket", - }, -) - -go_library( - name = "udp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "udp_packet_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp", - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - ], -) - -go_test( - name = "udp_x_test", - size = "small", - srcs = ["udp_test.go"], - deps = [ - ":udp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) - -filegroup( - name = "autogen", - srcs = [ - "udp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100755 index 000000000..673a9373b --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,173 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *udpPacketList) PushFront(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + udpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *udpPacketList) PushBack(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(nil) + udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + a := udpPacketElementMapper{}.linkerFor(b).Next() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + b := udpPacketElementMapper{}.linkerFor(a).Prev() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *udpPacketList) Remove(e *udpPacket) { + prev := udpPacketElementMapper{}.linkerFor(e).Prev() + next := udpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100755 index 000000000..eb9d419ef --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,128 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *udpPacket) beforeSave() {} +func (x *udpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("udpPacketEntry", &x.udpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) +} + +func (x *udpPacket) afterLoad() {} +func (x *udpPacket) load(m state.Map) { + m.Load("udpPacketEntry", &x.udpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("netProto", &x.netProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("id", &x.id) + m.Save("state", &x.state) + m.Save("bindNICID", &x.bindNICID) + m.Save("regNICID", &x.regNICID) + m.Save("dstPort", &x.dstPort) + m.Save("v6only", &x.v6only) + m.Save("multicastTTL", &x.multicastTTL) + m.Save("multicastAddr", &x.multicastAddr) + m.Save("multicastNICID", &x.multicastNICID) + m.Save("multicastLoop", &x.multicastLoop) + m.Save("reusePort", &x.reusePort) + m.Save("broadcast", &x.broadcast) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("multicastMemberships", &x.multicastMemberships) + m.Save("effectiveNetProtos", &x.effectiveNetProtos) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("id", &x.id) + m.Load("state", &x.state) + m.Load("bindNICID", &x.bindNICID) + m.Load("regNICID", &x.regNICID) + m.Load("dstPort", &x.dstPort) + m.Load("v6only", &x.v6only) + m.Load("multicastTTL", &x.multicastTTL) + m.Load("multicastAddr", &x.multicastAddr) + m.Load("multicastNICID", &x.multicastNICID) + m.Load("multicastLoop", &x.multicastLoop) + m.Load("reusePort", &x.reusePort) + m.Load("broadcast", &x.broadcast) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("multicastMemberships", &x.multicastMemberships) + m.Load("effectiveNetProtos", &x.effectiveNetProtos) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *multicastMembership) beforeSave() {} +func (x *multicastMembership) save(m state.Map) { + x.beforeSave() + m.Save("nicID", &x.nicID) + m.Save("multicastAddr", &x.multicastAddr) +} + +func (x *multicastMembership) afterLoad() {} +func (x *multicastMembership) load(m state.Map) { + m.Load("nicID", &x.nicID) + m.Load("multicastAddr", &x.multicastAddr) +} + +func (x *udpPacketList) beforeSave() {} +func (x *udpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *udpPacketList) afterLoad() {} +func (x *udpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *udpPacketEntry) beforeSave() {} +func (x *udpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *udpPacketEntry) afterLoad() {} +func (x *udpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load}) + state.Register("udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load}) + state.Register("udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load}) + state.Register("udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go deleted file mode 100644 index 75129a2ff..000000000 --- a/pkg/tcpip/transport/udp/udp_test.go +++ /dev/null @@ -1,1007 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_test - -import ( - "bytes" - "math" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 - multicastAddr = "\xe8\x2b\xd3\xea" - multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - multicastPort = 1234 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -type headers struct { - srcPort uint16 - dstPort uint16 -} - -func newDualTestContext(t *testing.T, mtu uint32) *testContext { - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) - - id, linkEP := channel.New(256, mtu, "") - if testing.Verbose() { - id = sniffer.New(id) - } - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }, - { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEP: linkEP, - } -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } -} - -func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != protocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) - } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) - var srcAddr, dstAddr tcpip.Address - switch protocolNumber { - case ipv4.ProtocolNumber: - checkerFn = checker.IPv4 - srcAddr, dstAddr = stackAddr, testAddr - if multicast { - dstAddr = multicastAddr - } - case ipv6.ProtocolNumber: - checkerFn = checker.IPv6 - srcAddr, dstAddr = stackV6Addr, testV6Addr - if multicast { - dstAddr = multicastV6Addr - } - default: - c.t.Fatalf("unknown protocol %d", protocolNumber) - } - checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -func (c *testContext) sendV6Packet(payload []byte, h *headers) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) -} - -func (c *testContext) sendPacket(payload []byte, h *headers) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testAddr, - DstAddr: stackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) -} - -func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func TestBindPortReuse(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - var eps [5]tcpip.Endpoint - reusePortOpt := tcpip.ReusePortOption(1) - - pollChannel := make(chan tcpip.Endpoint) - for i := 0; i < len(eps); i++ { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(eps[i]) - - defer eps[i].Close() - if err := eps[i].SetSockOpt(reusePortOpt); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - } - - npackets := 100000 - nports := 10000 - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort + port, - dstPort: stackPort, - }) - - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled - // by the same socket. - if ports[port] != ep { - t.Fatalf("Port mismatch") - } - } - } - - if len(stats) != len(eps) { - t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) - } - - // Check that a packet distribution is fair between sockets. - for _, c := range stats { - n := float64(npackets) / float64(len(eps)) - // The deviation is less than 10%. - if math.Abs(float64(c)-n) > n/10 { - t.Fatal(c, n) - } - } -} - -func testV4Read(c *testContext) { - // Send a packet. - payload := newPayload() - c.sendPacket(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") - } - } - - // Check the peer address. - if addr.Addr != testAddr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) - } -} - -func TestBindEphemeralPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } -} - -func TestBindReservedPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - addr, err := c.ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) - } - - // We can't bind the address reserved by the connected endpoint above. - { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - } - - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - // We can't bind ipv4-any on the port reserved by the connected endpoint - // above, since the endpoint is dual-stack. - if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - // We can bind an ipv4 address on this port, though. - if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() - - // Once the connected endpoint releases its port reservation, we are able to - // bind ipv4-any once again. - c.ep.Close() - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() -} - -func TestV4ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Read(c) -} - -func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Read(c) -} - -func TestV4ReadOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Bind to local address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Read(c) -} - -func TestV6ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Send a packet. - payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") - } - } - - // Check the peer address. - if addr.Addr != testV6Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) - } -} - -func TestV4ReadOnV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create v4 UDP endpoint. - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Read(c) -} - -func testV4Write(c *testContext) uint16 { - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } - - return udp.SourcePort() -} - -func testV6Write(c *testContext) uint16 { - // Write to v6 address. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } - - return udp.SourcePort() -} - -func testDualWrite(c *testContext) uint16 { - v4Port := testV4Write(c) - v6Port := testV6Write(c) - if v4Port != v6Port { - c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) - } - - return v4Port -} - -func TestDualWriteUnbound(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - testDualWrite(c) -} - -func TestDualWriteBoundToWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - p := testDualWrite(c) - if p != stackPort { - c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) - } -} - -func TestDualWriteConnectedToV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testV6Write(c) - - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNetworkUnreachable { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) - } -} - -func TestDualWriteConnectedToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testV4Write(c) - - // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } -} - -func TestV4WriteOnV6Only(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(true) - - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNoRoute { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) - } -} - -func TestV6WriteOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Bind to v4 mapped address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } -} - -func TestV6WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } -} - -func TestV4WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } -} - -func TestReadIncrementsPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create IPv4 UDP endpoint - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testV4Read(c) - - var want uint64 = 1 - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { - c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) - } -} - -func TestWriteIncrementsPacketsSent(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createV6Endpoint(false) - - testDualWrite(c) - - var want uint64 = 2 - if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { - c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) - } -} - -func TestTTL(t *testing.T) { - payload := tcpip.SlicePayload(buffer.View(newPayload())) - - for _, name := range []string{"v4", "v6", "dual"} { - t.Run(name, func(t *testing.T) { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch name { - case "v4": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } - - var variants []string - switch name { - case "v4": - variants = []string{"v4"} - case "v6": - variants = []string{"v6"} - case "dual": - variants = []string{"v6", "mapped"} - } - - for _, variant := range variants { - t.Run(variant, func(t *testing.T) { - for _, typ := range []string{"unicast", "multicast"} { - t.Run(typ, func(t *testing.T) { - var addr tcpip.Address - var port uint16 - switch typ { - case "unicast": - port = testPort - switch variant { - case "v4": - addr = testAddr - case "mapped": - addr = testV4MappedAddr - case "v6": - addr = testV6Addr - default: - t.Fatal("unknown test variant") - } - case "multicast": - port = multicastPort - switch variant { - case "v4": - addr = multicastAddr - case "mapped": - addr = multicastV4MappedAddr - case "v6": - addr = multicastV6Addr - default: - t.Fatal("unknown test variant") - } - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - switch name { - case "v4": - case "v6": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - case "dual": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - default: - t.Fatal("unknown test variant") - } - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) - } - - checkerFn := checker.IPv4 - switch variant { - case "v4", "mapped": - case "v6": - checkerFn = checker.IPv6 - default: - t.Fatal("unknown test variant") - } - var wantTTL uint8 - var multicast bool - switch typ { - case "unicast": - multicast = false - switch variant { - case "v4", "mapped": - ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - case "v6": - ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - default: - t.Fatal("unknown test variant") - } - case "multicast": - wantTTL = multicastTTL - multicast = true - default: - t.Fatal("unknown test variant") - } - - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch variant { - case "v4", "mapped": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } - - b := c.getPacket(networkProtocolNumber, multicast) - checkerFn(c.t, b, - checker.TTL(wantTTL), - checker.UDP( - checker.DstPort(port), - ), - ) - }) - } - }) - } - }) - } -} diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD deleted file mode 100644 index 98d51cc69..000000000 --- a/pkg/tmutex/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tmutex", - srcs = ["tmutex.go"], - importpath = "gvisor.dev/gvisor/pkg/tmutex", - visibility = ["//:sandbox"], -) - -go_test( - name = "tmutex_test", - size = "medium", - srcs = ["tmutex_test.go"], - embed = [":tmutex"], -) diff --git a/pkg/tmutex/tmutex_state_autogen.go b/pkg/tmutex/tmutex_state_autogen.go new file mode 100755 index 000000000..2b2bb599e --- /dev/null +++ b/pkg/tmutex/tmutex_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package tmutex + diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go deleted file mode 100644 index ce34c7962..000000000 --- a/pkg/tmutex/tmutex_test.go +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmutex - -import ( - "fmt" - "runtime" - "sync" - "sync/atomic" - "testing" - "time" -) - -func TestBasicLock(t *testing.T) { - var m Mutex - m.Init() - - m.Lock() - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - ch <- struct{}{} - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } - - // Make sure we can lock and unlock again. - m.Lock() - m.Unlock() -} - -func TestTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Try to lock. It should succeed. - if !m.TryLock() { - t.Fatalf("TryLock failed on unlocked mutex") - } - - // Try to lock again, it should now fail. - if m.TryLock() { - t.Fatalf("TryLock succeeded on locked mutex") - } - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } -} - -func TestMutualExclusion(t *testing.T) { - var m Mutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes to it to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestMutualExclusionWithTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Similar to the previous, with the addition of some goroutines that - // only increment the count if TryLock succeeds. - const gr = 1000 - const iters = 100000 - total := int64(gr * iters) - var tryTotal int64 - v := int64(0) - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(2) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - go func() { - local := int64(0) - for j := 0; j < iters; j++ { - if m.TryLock() { - v++ - m.Unlock() - local++ - } - } - atomic.AddInt64(&tryTotal, local) - wg.Done() - }() - } - - wg.Wait() - - t.Logf("tryTotal = %d", tryTotal) - total += tryTotal - - if v != total { - t.Fatalf("Bad count: got %v, want %v", v, total) - } -} - -// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following -// differences: -// -// - The number of goroutines is variable, with the maximum value depending on -// GOMAXPROCS. -// -// - The number of iterations per benchmark is controlled by the benchmarking -// framework. -// -// - Care is taken to ensure that all goroutines participating in the benchmark -// have been created before the benchmark begins. -func BenchmarkTmutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m Mutex - m.Init() - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} - -// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as -// a comparison point. -func BenchmarkSyncMutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m sync.Mutex - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD deleted file mode 100644 index 769509e80..000000000 --- a/pkg/unet/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "unet", - srcs = [ - "unet.go", - "unet_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/unet", - visibility = ["//visibility:public"], - deps = [ - "//pkg/abi/linux", - "//pkg/gate", - ], -) - -go_test( - name = "unet_test", - size = "small", - srcs = [ - "unet_test.go", - ], - embed = [":unet"], -) diff --git a/pkg/unet/unet_state_autogen.go b/pkg/unet/unet_state_autogen.go new file mode 100755 index 000000000..1f7c7fa59 --- /dev/null +++ b/pkg/unet/unet_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package unet + diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go deleted file mode 100644 index a3cc6f5d3..000000000 --- a/pkg/unet/unet_test.go +++ /dev/null @@ -1,735 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package unet - -import ( - "io/ioutil" - "os" - "path/filepath" - "reflect" - "sync" - "syscall" - "testing" - "time" -) - -func randomFilename() (string, error) { - // Return a randomly generated file in the test dir. - f, err := ioutil.TempFile("", "unet-test") - if err != nil { - return "", err - } - file := f.Name() - os.Remove(file) - f.Close() - - cwd, err := os.Getwd() - if err != nil { - return "", err - } - - // NOTE(b/26918832): We try to use relative path if possible. This is - // to help conforming to the unix path length limit. - if rel, err := filepath.Rel(cwd, file); err == nil { - return rel, nil - } - - return file, nil -} - -func TestConnectFailure(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - if _, err := Connect(name, false); err == nil { - t.Fatalf("connect was successful, expected err") - } -} - -func TestBindFailure(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) - } - defer ss.Close() - - if _, err = BindAndListen(name, false); err == nil { - t.Fatalf("second bind succeeded, expected non-nil err") - } -} - -func TestMultipleAccept(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) - } - defer ss.Close() - - // Connect backlog times asynchronously. - var wg sync.WaitGroup - defer wg.Wait() - for i := 0; i < backlog; i++ { - wg.Add(1) - go func() { - defer wg.Done() - s, err := Connect(name, false) - if err != nil { - t.Fatalf("connect failed, got err %v expected nil", err) - } - s.Close() - }() - } - - // Accept backlog times. - for i := 0; i < backlog; i++ { - s, err := ss.Accept() - if err != nil { - t.Errorf("accept failed, got err %v expected nil", err) - continue - } - s.Close() - } -} - -func TestServerClose(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) - } - - // Make sure the first close succeeds. - if err := ss.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) - } - - // The second one should fail. - if err := ss.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") - } -} - -func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - // Bind a server. - ss, err := BindAndListen(name, packet) - if err != nil { - t.Fatalf("error binding, got %v expected nil", err) - } - defer ss.Close() - - // Accept a client. - acceptSocket := make(chan *Socket) - acceptErr := make(chan error) - go func() { - server, err := ss.Accept() - if err != nil { - acceptErr <- err - } - acceptSocket <- server - }() - - // Connect the client. - client, err := Connect(name, packet) - if err != nil { - t.Fatalf("error connecting, got %v expected nil", err) - } - - // Grab the server handle. - select { - case server := <-acceptSocket: - return server, client - case err := <-acceptErr: - t.Fatalf("accept error: %v", err) - } - panic("unreachable") -} - -func TestSendRecv(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'b'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } -} - -// TestSymmetric exists to assert that the two sockets received from socketPair -// are interchangeable. They should be, this just provides a basic sanity check -// by running TestSendRecv "backwards". -func TestSymmetric(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the server. - w := server.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the client. - b := [][]byte{{'b'}} - r := client.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } -} - -func TestPacket(t *testing.T) { - server, client := socketPair(t, true) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Write on the client again. - w = client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the server. - // - // This should only get back a single byte, despite the buffer - // being size two. This is because it's a _packet_ buffer. - b := [][]byte{{'b', 'b'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } - - // Do it again. - r = server.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } -} - -func TestClose(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - - // Make sure the first close succeeds. - if err := client.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) - } - - // The second one should fail. - if err := client.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") - } -} - -func TestNonBlockingSend(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Try up to 1000 writes, of 1000 bytes. - blockCount := 0 - for i := 0; i < 1000; i++ { - w := client.Writer(false) - if n, err := w.WriteVec([][]byte{make([]byte, 1000)}); n != 1000 || err != nil { - if err == syscall.EWOULDBLOCK || err == syscall.EAGAIN { - // We're good. That's what we wanted. - blockCount++ - } else { - t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err) - } - } - } - - if blockCount == 1000 { - // Shouldn't have _always_ blocked. - t.Fatalf("socket always blocked!") - } else if blockCount == 0 { - // Should have started blocking eventually. - t.Fatalf("socket never blocked!") - } -} - -func TestNonBlockingRecv(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - b := [][]byte{{'b'}} - r := client.Reader(false) - - // Expected to block immediately. - _, err := r.ReadVec(b) - if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) - } - - // Put some data in the pipe. - w := server.Writer(false) - if n, err := w.WriteVec(b); n != 1 || err != nil { - t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Expect it not to block. - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Expect it to return a block error again. - r = client.Reader(false) - _, err = r.ReadVec(b) - if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) - } -} - -func TestRecvVectors(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'c'}, {'c'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) - } - if b[0][0] != 'a' || b[1][0] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) - } -} - -func TestSendVectors(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the client. - w := client.Writer(true) - if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) - } - - // Read on the server. - b := [][]byte{{'c', 'c'}} - r := server.Reader(true) - if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) - } - if b[0][0] != 'a' || b[0][1] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) - } -} - -func TestSendFDsNotEnabled(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Write on the server. - w := server.Writer(true) - w.PackFDs(0, 1, 2) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) - } - - // Read on the client, without enabling FDs. - b := [][]byte{{'b'}} - r := client.Reader(true) - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } - - // Make sure the FDs are not received. - fds, err := r.ExtractFDs() - if len(fds) != 0 || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) - } -} - -func sendFDs(t *testing.T, s *Socket, fds []int) { - w := s.Writer(true) - w.PackFDs(fds...) - if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err) - } -} - -func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { - expected := len(origFDs) - - // Count the number of FDs. - preEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) - } - - // Read on the client. - b := [][]byte{{'b'}} - r := s.Reader(true) - if enableSize >= 0 { - r.EnableFDs(enableSize) - } - if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) - } - if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) - } - - // Count the new number of FDs. - postEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) - } - if len(preEntries)+expected != len(postEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) - } - - // Make sure the FDs are there. - fds, err := r.ExtractFDs() - if len(fds) != expected || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) - } - - // Make sure they are different from the originals. - for i := 0; i < len(fds); i++ { - if fds[i] == origFDs[i] { - t.Errorf("got original fd for index %d, expected different", i) - } - } - - // Make sure they can be accessed as expected. - for i := 0; i < len(fds); i++ { - var st syscall.Stat_t - if err := syscall.Fstat(fds[i], &st); err != nil { - t.Errorf("fds[%d] can't be stated, got err %v expected nil", i, err) - } - } - - // Close them off. - r.CloseFDs() - - // Make sure the count is back to normal. - finalEntries, err := ioutil.ReadDir("/proc/self/fd") - if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) - } - if len(finalEntries) != len(preEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) - } -} - -func TestFDsSingle(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 1, []int{0}) -} - -func TestFDsMultiple(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - // Basic case, multiple FDs. - sendFDs(t, server, []int{0, 1, 2}) - recvFDs(t, client, 3, []int{0, 1, 2}) -} - -// See TestSymmetric above. -func TestFDsSymmetric(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0, 1, 2}) - recvFDs(t, client, 3, []int{0, 1, 2}) -} - -func TestFDsReceiveLargeBuffer(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 3, []int{0}) -} - -func TestFDsReceiveSmallBuffer(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0, 1, 2}) - - // Per the spec, we may still receive more than the buffer. In fact, - // it'll be rounded up and we can receive two with a size one buffer. - recvFDs(t, client, 1, []int{0, 1}) -} - -func TestFDsReceiveNotEnabled(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, -1, []int{}) -} - -func TestFDsReceiveSizeZero(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - sendFDs(t, server, []int{0}) - recvFDs(t, client, 0, []int{}) -} - -func TestGetPeerCred(t *testing.T) { - server, client := socketPair(t, false) - defer server.Close() - defer client.Close() - - want := &syscall.Ucred{ - Pid: int32(os.Getpid()), - Uid: uint32(os.Getuid()), - Gid: uint32(os.Getgid()), - } - - if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) - } -} - -func newClosedSocket() (*Socket, error) { - fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - return nil, err - } - - s, err := NewSocket(fd) - if err != nil { - syscall.Close(fd) - return nil, err - } - - return s, s.Close() -} - -func TestGetPeerCredFailure(t *testing.T) { - s, err := newClosedSocket() - if err != nil { - t.Fatalf("newClosedSocket got error %v want nil", err) - } - - want := "bad file descriptor" - if _, err := s.GetPeerCred(); err == nil || err.Error() != want { - t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want) - } -} - -func TestAcceptClosed(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) - } - - if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) - } - - if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) - } -} - -func TestCloseAfterAcceptStart(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - time.Sleep(50 * time.Millisecond) - if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) - } - wg.Done() - }() - - if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) - } - - wg.Wait() -} - -func TestReleaseAfterAcceptStart(t *testing.T) { - name, err := randomFilename() - if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) - } - - ss, err := BindAndListen(name, false) - if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - time.Sleep(50 * time.Millisecond) - fd, err := ss.Release() - if err != nil { - t.Fatalf("Release failed, got err %v expected nil", err) - } - syscall.Close(fd) - wg.Done() - }() - - if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) - } - - wg.Wait() -} - -func TestControlMessage(t *testing.T) { - for i := 0; i <= 10; i++ { - var want []int - for j := 0; j < i; j++ { - want = append(want, i+j+1) - } - - var cm ControlMessage - cm.EnableFDs(i) - cm.PackFDs(want...) - got, err := cm.ExtractFDs() - if err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) - } - } -} - -func benchmarkSendRecv(b *testing.B, packet bool) { - server, client, err := SocketPair(packet) - if err != nil { - b.Fatalf("SocketPair: got %v, wanted nil", err) - } - defer server.Close() - defer client.Close() - go func() { - buf := make([]byte, 1) - for i := 0; i < b.N; i++ { - n, err := server.Read(buf) - if n != 1 || err != nil { - b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err) - } - n, err = server.Write(buf) - if n != 1 || err != nil { - b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err) - } - } - }() - buf := make([]byte, 1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - n, err := client.Write(buf) - if n != 1 || err != nil { - b.Fatalf("client.Write: got (%d, %v), wanted (1, nil)", n, err) - } - n, err = client.Read(buf) - if n != 1 || err != nil { - b.Fatalf("client.Read: got (%d, %v), wanted (1, nil)", n, err) - } - } -} - -func BenchmarkSendRecvStream(b *testing.B) { - benchmarkSendRecv(b, false) -} - -func BenchmarkSendRecvPacket(b *testing.B) { - benchmarkSendRecv(b, true) -} diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD deleted file mode 100644 index b7f505a84..000000000 --- a/pkg/urpc/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "urpc", - srcs = ["urpc.go"], - importpath = "gvisor.dev/gvisor/pkg/urpc", - visibility = ["//:sandbox"], - deps = [ - "//pkg/fd", - "//pkg/log", - "//pkg/unet", - ], -) - -go_test( - name = "urpc_test", - size = "small", - srcs = ["urpc_test.go"], - embed = [":urpc"], - deps = ["//pkg/unet"], -) diff --git a/pkg/urpc/urpc_state_autogen.go b/pkg/urpc/urpc_state_autogen.go new file mode 100755 index 000000000..01bf2172f --- /dev/null +++ b/pkg/urpc/urpc_state_autogen.go @@ -0,0 +1,4 @@ +// automatically generated by stateify. + +package urpc + diff --git a/pkg/urpc/urpc_test.go b/pkg/urpc/urpc_test.go deleted file mode 100644 index c6c7ce9d4..000000000 --- a/pkg/urpc/urpc_test.go +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package urpc - -import ( - "errors" - "os" - "testing" - - "gvisor.dev/gvisor/pkg/unet" -) - -type test struct { -} - -type testArg struct { - StringArg string - IntArg int - FilePayload -} - -type testResult struct { - StringResult string - IntResult int - FilePayload -} - -func (t test) Func(a *testArg, r *testResult) error { - r.StringResult = a.StringArg - r.IntResult = a.IntArg - return nil -} - -func (t test) Err(a *testArg, r *testResult) error { - return errors.New("test error") -} - -func (t test) FailNoFile(a *testArg, r *testResult) error { - if a.Files == nil { - return errors.New("no file found") - } - - return nil -} - -func (t test) SendFile(a *testArg, r *testResult) error { - r.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr} - return nil -} - -func (t test) TooManyFiles(a *testArg, r *testResult) error { - for i := 0; i <= maxFiles; i++ { - r.Files = append(r.Files, os.Stdin) - } - return nil -} - -func startServer(socket *unet.Socket) { - s := NewServer() - s.Register(test{}) - s.StartHandling(socket) -} - -func testClient() (*Client, error) { - serverSock, clientSock, err := unet.SocketPair(false) - if err != nil { - return nil, err - } - startServer(serverSock) - - return NewClient(clientSock), nil -} - -func TestCall(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Func", &testArg{}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.StringResult != "" || r.IntResult != 0 { - t.Errorf("unexpected result, got %v expected zero value", r) - } - if err := c.Call("test.Func", &testArg{StringArg: "hello"}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.StringResult != "hello" { - t.Errorf("unexpected result, got %v expected hello", r.StringResult) - } - if err := c.Call("test.Func", &testArg{IntArg: 1}, &r); err != nil { - t.Errorf("basic call failed: %v", err) - } else if r.IntResult != 1 { - t.Errorf("unexpected result, got %v expected 1", r.IntResult) - } -} - -func TestUnknownMethod(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Unknown", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != ErrUnknownMethod.Error() { - t.Errorf("expected test error, got %v", err) - } -} - -func TestErr(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.Err", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != "test error" { - t.Errorf("expected test error, got %v", err) - } -} - -func TestSendFile(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.FailNoFile", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } - if err := c.Call("test.FailNoFile", &testArg{FilePayload: FilePayload{Files: []*os.File{os.Stdin, os.Stdout, os.Stdin}}}, &r); err != nil { - t.Errorf("expected nil err, got %v", err) - } -} - -func TestRecvFile(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - if err := c.Call("test.SendFile", &testArg{}, &r); err != nil { - t.Errorf("expected nil err, got %v", err) - } - if r.Files == nil { - t.Errorf("expected file, got nil") - } -} - -func TestShutdown(t *testing.T) { - serverSock, clientSock, err := unet.SocketPair(false) - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - clientSock.Close() - - s := NewServer() - if err := s.Handle(serverSock); err == nil { - t.Errorf("expected non-nil err, got nil") - } -} - -func TestTooManyFiles(t *testing.T) { - c, err := testClient() - if err != nil { - t.Fatalf("error creating test client: %v", err) - } - defer c.Close() - - var r testResult - var a testArg - for i := 0; i <= maxFiles; i++ { - a.Files = append(a.Files, os.Stdin) - } - - // Client-side error. - if err := c.Call("test.Func", &a, &r); err != ErrTooManyFiles { - t.Errorf("expected ErrTooManyFiles, got %v", err) - } - - // Server-side error. - if err := c.Call("test.TooManyFiles", &testArg{}, &r); err == nil { - t.Errorf("expected non-nil err, got nil") - } else if err.Error() != "too many files" { - t.Errorf("expected too many files, got %v", err.Error()) - } -} diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD deleted file mode 100644 index 9173dfd0f..000000000 --- a/pkg/waiter/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -package(licenses = ["notice"]) - -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "waiter", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Entry", - "Linker": "*Entry", - }, -) - -go_library( - name = "waiter", - srcs = [ - "waiter.go", - "waiter_list.go", - ], - importpath = "gvisor.dev/gvisor/pkg/waiter", - visibility = ["//visibility:public"], -) - -go_test( - name = "waiter_test", - size = "small", - srcs = [ - "waiter_test.go", - ], - embed = [":waiter"], -) - -filegroup( - name = "autogen", - srcs = [ - "waiter_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/waiter/waiter_list.go b/pkg/waiter/waiter_list.go new file mode 100755 index 000000000..00b304a31 --- /dev/null +++ b/pkg/waiter/waiter_list.go @@ -0,0 +1,173 @@ +package waiter + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *Entry) *Entry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *Entry + tail *Entry +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *waiterList) Front() *Entry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *waiterList) Back() *Entry { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *waiterList) PushFront(e *Entry) { + waiterElementMapper{}.linkerFor(e).SetNext(l.head) + waiterElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *waiterList) PushBack(e *Entry) { + waiterElementMapper{}.linkerFor(e).SetNext(nil) + waiterElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *waiterList) InsertAfter(b, e *Entry) { + a := waiterElementMapper{}.linkerFor(b).Next() + waiterElementMapper{}.linkerFor(e).SetNext(a) + waiterElementMapper{}.linkerFor(e).SetPrev(b) + waiterElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *waiterList) InsertBefore(a, e *Entry) { + b := waiterElementMapper{}.linkerFor(a).Prev() + waiterElementMapper{}.linkerFor(e).SetNext(a) + waiterElementMapper{}.linkerFor(e).SetPrev(b) + waiterElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *waiterList) Remove(e *Entry) { + prev := waiterElementMapper{}.linkerFor(e).Prev() + next := waiterElementMapper{}.linkerFor(e).Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *Entry + prev *Entry +} + +// Next returns the entry that follows e in the list. +func (e *waiterEntry) Next() *Entry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *waiterEntry) Prev() *Entry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *waiterEntry) SetNext(elem *Entry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *waiterEntry) SetPrev(elem *Entry) { + e.prev = elem +} diff --git a/pkg/waiter/waiter_state_autogen.go b/pkg/waiter/waiter_state_autogen.go new file mode 100755 index 000000000..b73646808 --- /dev/null +++ b/pkg/waiter/waiter_state_autogen.go @@ -0,0 +1,67 @@ +// automatically generated by stateify. + +package waiter + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Entry) beforeSave() {} +func (x *Entry) save(m state.Map) { + x.beforeSave() + m.Save("Context", &x.Context) + m.Save("Callback", &x.Callback) + m.Save("mask", &x.mask) + m.Save("waiterEntry", &x.waiterEntry) +} + +func (x *Entry) afterLoad() {} +func (x *Entry) load(m state.Map) { + m.Load("Context", &x.Context) + m.Load("Callback", &x.Callback) + m.Load("mask", &x.mask) + m.Load("waiterEntry", &x.waiterEntry) +} + +func (x *Queue) beforeSave() {} +func (x *Queue) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(x.list) { m.Failf("list is %v, expected zero", x.list) } +} + +func (x *Queue) afterLoad() {} +func (x *Queue) load(m state.Map) { +} + +func (x *waiterList) beforeSave() {} +func (x *waiterList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *waiterList) afterLoad() {} +func (x *waiterList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *waiterEntry) beforeSave() {} +func (x *waiterEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *waiterEntry) afterLoad() {} +func (x *waiterEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("waiter.Entry", (*Entry)(nil), state.Fns{Save: (*Entry).save, Load: (*Entry).load}) + state.Register("waiter.Queue", (*Queue)(nil), state.Fns{Save: (*Queue).save, Load: (*Queue).load}) + state.Register("waiter.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load}) + state.Register("waiter.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load}) +} diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go deleted file mode 100644 index c1b94a4f3..000000000 --- a/pkg/waiter/waiter_test.go +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waiter - -import ( - "sync/atomic" - "testing" -) - -type callbackStub struct { - f func(e *Entry) -} - -// Callback implements EntryCallback.Callback. -func (c *callbackStub) Callback(e *Entry) { - c.f(e) -} - -func TestEmptyQueue(t *testing.T) { - var q Queue - - // Notify the zero-value of a queue. - q.Notify(EventIn) - - // Register then unregister a waiter, then notify the queue. - cnt := 0 - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} - q.EventRegister(&e, EventIn) - q.EventUnregister(&e) - q.Notify(EventIn) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestMask(t *testing.T) { - // Register a waiter. - var q Queue - var cnt int - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} - q.EventRegister(&e, EventIn|EventErr) - - // Notify with an overlapping mask. - cnt = 0 - q.Notify(EventIn | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a subset mask. - cnt = 0 - q.Notify(EventIn) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a superset mask. - cnt = 0 - q.Notify(EventIn | EventErr | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with the exact same mask. - cnt = 0 - q.Notify(EventIn | EventErr) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a disjoint mask. - cnt = 0 - q.Notify(EventOut | EventHUp) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestConcurrentRegistration(t *testing.T) { - var q Queue - var cnt int - const concurrency = 1000 - - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - ch3 := make(chan struct{}) - - // Create goroutines that will all register/unregister concurrently. - for i := 0; i < concurrency; i++ { - go func() { - var e Entry - e.Callback = &callbackStub{func(entry *Entry) { - cnt++ - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - }} - - // Wait for notification, then register. - <-ch1 - q.EventRegister(&e, EventIn|EventErr) - - // Tell main goroutine that we're done registering. - ch2 <- struct{}{} - - // Wait for notification, then unregister. - <-ch3 - q.EventUnregister(&e) - - // Tell main goroutine that we're done unregistering. - ch2 <- struct{}{} - }() - } - - // Let the goroutines register. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } - - // Let the goroutine unregister. - close(ch3) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } -} - -func TestConcurrentNotification(t *testing.T) { - var q Queue - var cnt int32 - const concurrency = 1000 - const waiterCount = 1000 - - // Register waiters. - for i := 0; i < waiterCount; i++ { - var e Entry - e.Callback = &callbackStub{func(entry *Entry) { - atomic.AddInt32(&cnt, 1) - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - }} - - q.EventRegister(&e, EventIn|EventErr) - } - - // Launch notifiers. - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - for i := 0; i < concurrency; i++ { - go func() { - <-ch1 - q.Notify(EventIn) - ch2 <- struct{}{} - }() - } - - // Let notifiers go. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Check the count. - if cnt != concurrency*waiterCount { - t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) - } -} |