diff options
Diffstat (limited to 'pkg/sentry/socket')
30 files changed, 1756 insertions, 814 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD deleted file mode 100644 index cc1f6bfcc..000000000 --- a/pkg/sentry/socket/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "socket", - srcs = ["socket.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/marshal", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD deleted file mode 100644 index fb7c5dc61..000000000 --- a/pkg/sentry/socket/control/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "control", - srcs = [ - "control.go", - "control_vfs2.go", - ], - imports = [ - "gvisor.dev/gvisor/pkg/sentry/fs", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/sentry/fs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/syserror", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/socket/control/control_state_autogen.go b/pkg/sentry/socket/control/control_state_autogen.go new file mode 100644 index 000000000..0f567afd4 --- /dev/null +++ b/pkg/sentry/socket/control/control_state_autogen.go @@ -0,0 +1,58 @@ +// automatically generated by stateify. + +package control + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (fs *RightsFiles) StateTypeName() string { + return "pkg/sentry/socket/control.RightsFiles" +} + +func (fs *RightsFiles) StateFields() []string { + return nil +} + +func (c *scmCredentials) StateTypeName() string { + return "pkg/sentry/socket/control.scmCredentials" +} + +func (c *scmCredentials) StateFields() []string { + return []string{ + "t", + "kuid", + "kgid", + } +} + +func (c *scmCredentials) beforeSave() {} + +func (c *scmCredentials) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.t) + stateSinkObject.Save(1, &c.kuid) + stateSinkObject.Save(2, &c.kgid) +} + +func (c *scmCredentials) afterLoad() {} + +func (c *scmCredentials) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.t) + stateSourceObject.Load(1, &c.kuid) + stateSourceObject.Load(2, &c.kgid) +} + +func (fs *RightsFilesVFS2) StateTypeName() string { + return "pkg/sentry/socket/control.RightsFilesVFS2" +} + +func (fs *RightsFilesVFS2) StateFields() []string { + return nil +} + +func init() { + state.Register((*RightsFiles)(nil)) + state.Register((*scmCredentials)(nil)) + state.Register((*RightsFilesVFS2)(nil)) +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD deleted file mode 100644 index b6ebe29d6..000000000 --- a/pkg/sentry/socket/hostinet/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "hostinet", - srcs = [ - "device.go", - "hostinet.go", - "save_restore.go", - "socket.go", - "socket_unsafe.go", - "socket_vfs2.go", - "sockopt_impl.go", - "stack.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/fdnotifier", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/hostfd", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go new file mode 100644 index 000000000..b0a59ba93 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_impl_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostinet diff --git a/pkg/sentry/socket/hostinet/hostinet_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go new file mode 100644 index 000000000..a229dcec5 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_state_autogen.go @@ -0,0 +1,85 @@ +// automatically generated by stateify. + +package hostinet + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/hostinet.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "family", + "stype", + "protocol", + "queue", + "fd", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.family) + stateSinkObject.Save(2, &s.stype) + stateSinkObject.Save(3, &s.protocol) + stateSinkObject.Save(4, &s.queue) + stateSinkObject.Save(5, &s.fd) +} + +func (s *socketOpsCommon) afterLoad() {} + +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.family) + stateSourceObject.Load(2, &s.stype) + stateSourceObject.Load(3, &s.protocol) + stateSourceObject.Load(4, &s.queue) + stateSourceObject.Load(5, &s.fd) +} + +func (s *socketVFS2) StateTypeName() string { + return "pkg/sentry/socket/hostinet.socketVFS2" +} + +func (s *socketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "LockFD", + "DentryMetadataFileDescriptionImpl", + "socketOpsCommon", + } +} + +func (s *socketVFS2) beforeSave() {} + +func (s *socketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.LockFD) + stateSinkObject.Save(3, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(4, &s.socketOpsCommon) +} + +func (s *socketVFS2) afterLoad() {} + +func (s *socketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.LockFD) + stateSourceObject.Load(3, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(4, &s.socketOpsCommon) +} + +func init() { + state.Register((*socketOpsCommon)(nil)) + state.Register((*socketVFS2)(nil)) +} diff --git a/pkg/sentry/socket/hostinet/hostinet_unsafe_state_autogen.go b/pkg/sentry/socket/hostinet/hostinet_unsafe_state_autogen.go new file mode 100644 index 000000000..b0a59ba93 --- /dev/null +++ b/pkg/sentry/socket/hostinet/hostinet_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package hostinet diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD deleted file mode 100644 index 8aea0200f..000000000 --- a/pkg/sentry/socket/netfilter/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "netfilter", - srcs = [ - "extensions.go", - "ipv4.go", - "ipv6.go", - "netfilter.go", - "owner_matcher.go", - "targets.go", - "tcp_matcher.go", - "udp_matcher.go", - ], - # This target depends on netstack and should only be used by epsocket, - # which is allowed to depend on netstack. - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/log", - "//pkg/sentry/kernel", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/usermem", - ], -) diff --git a/pkg/sentry/socket/netfilter/netfilter_state_autogen.go b/pkg/sentry/socket/netfilter/netfilter_state_autogen.go new file mode 100644 index 000000000..6e95d89a4 --- /dev/null +++ b/pkg/sentry/socket/netfilter/netfilter_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package netfilter diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD deleted file mode 100644 index 1f926aa91..000000000 --- a/pkg/sentry/socket/netlink/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "netlink", - srcs = [ - "message.go", - "provider.go", - "provider_vfs2.go", - "socket.go", - "socket_vfs2.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/context", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netlink/port", - "//pkg/sentry/socket/unix", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - ], -) - -go_test( - name = "netlink_test", - size = "small", - srcs = [ - "message_test.go", - ], - deps = [ - ":netlink", - "//pkg/abi/linux", - ], -) diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go deleted file mode 100644 index ef13d9386..000000000 --- a/pkg/sentry/socket/netlink/message_test.go +++ /dev/null @@ -1,312 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package message_test - -import ( - "bytes" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/socket/netlink" -) - -type dummyNetlinkMsg struct { - Foo uint16 -} - -func TestParseMessage(t *testing.T) { - tests := []struct { - desc string - input []byte - - header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg - restLen int - ok bool - }{ - { - desc: "valid", - input: []byte{ - 0x14, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - header: linux.NetlinkMessageHeader{ - Length: 20, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "valid with next message", - input: []byte{ - 0x14, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - 0xFF, // Next message (rest) - }, - header: linux.NetlinkMessageHeader{ - Length: 20, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 1, - ok: true, - }, - { - desc: "valid for last message without padding", - input: []byte{ - 0x12, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, // Data message - }, - header: linux.NetlinkMessageHeader{ - Length: 18, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "valid for last message not to be aligned", - input: []byte{ - 0x13, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, // Data message - 0x00, // Excessive 1 byte permitted at end - }, - header: linux.NetlinkMessageHeader{ - Length: 19, - Type: 1, - Flags: 2, - Seq: 3, - PortID: 4, - }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, - restLen: 0, - ok: true, - }, - { - desc: "header.Length too short", - input: []byte{ - 0x04, 0x00, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - ok: false, - }, - { - desc: "header.Length too long", - input: []byte{ - 0xFF, 0xFF, 0x00, 0x00, // Length - 0x01, 0x00, // Type - 0x02, 0x00, // Flags - 0x03, 0x00, 0x00, 0x00, // Seq - 0x04, 0x00, 0x00, 0x00, // PortID - 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding - }, - ok: false, - }, - { - desc: "header incomplete", - input: []byte{ - 0x04, 0x00, 0x00, 0x00, // Length - }, - ok: false, - }, - { - desc: "empty message", - input: []byte{}, - ok: false, - }, - } - for _, test := range tests { - msg, rest, ok := netlink.ParseMessage(test.input) - if ok != test.ok { - t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) - continue - } - if !test.ok { - continue - } - if !reflect.DeepEqual(msg.Header(), test.header) { - t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) - } - - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) - if !dataOk { - t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { - t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) - } - - if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) { - t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want) - } - } -} - -func TestAttrView(t *testing.T) { - tests := []struct { - desc string - input []byte - - // Outputs for ParseFirst. - hdr linux.NetlinkAttrHeader - value []byte - restLen int - ok bool - - // Outputs for Empty. - isEmpty bool - }{ - { - desc: "valid", - input: []byte{ - 0x06, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding - }, - hdr: linux.NetlinkAttrHeader{ - Length: 6, - Type: 1, - }, - value: []byte{0x30, 0x31}, - restLen: 0, - ok: true, - isEmpty: false, - }, - { - desc: "at alignment", - input: []byte{ - 0x08, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - hdr: linux.NetlinkAttrHeader{ - Length: 8, - Type: 1, - }, - value: []byte{0x30, 0x31, 0x32, 0x33}, - restLen: 0, - ok: true, - isEmpty: false, - }, - { - desc: "at alignment with rest data", - input: []byte{ - 0x08, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - 0xFF, 0xFE, // Rest data - }, - hdr: linux.NetlinkAttrHeader{ - Length: 8, - Type: 1, - }, - value: []byte{0x30, 0x31, 0x32, 0x33}, - restLen: 2, - ok: true, - isEmpty: false, - }, - { - desc: "hdr.Length too long", - input: []byte{ - 0xFF, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - ok: false, - isEmpty: false, - }, - { - desc: "hdr.Length too short", - input: []byte{ - 0x01, 0x00, // Length - 0x01, 0x00, // Type - 0x30, 0x31, 0x32, 0x33, // Data - }, - ok: false, - isEmpty: false, - }, - { - desc: "empty", - input: []byte{}, - ok: false, - isEmpty: true, - }, - } - for _, test := range tests { - attrs := netlink.AttrsView(test.input) - - // Test ParseFirst(). - hdr, value, rest, ok := attrs.ParseFirst() - if ok != test.ok { - t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) - } else if test.ok { - if !reflect.DeepEqual(hdr, test.hdr) { - t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr) - } - if !bytes.Equal(value, test.value) { - t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value) - } - if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) { - t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest) - } - } - - // Test Empty(). - if got, want := attrs.Empty(), test.isEmpty; got != want { - t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want) - } - } -} diff --git a/pkg/sentry/socket/netlink/netlink_state_autogen.go b/pkg/sentry/socket/netlink/netlink_state_autogen.go new file mode 100644 index 000000000..307498c1c --- /dev/null +++ b/pkg/sentry/socket/netlink/netlink_state_autogen.go @@ -0,0 +1,141 @@ +// automatically generated by stateify. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *Socket) StateTypeName() string { + return "pkg/sentry/socket/netlink.Socket" +} + +func (s *Socket) StateFields() []string { + return []string{ + "socketOpsCommon", + } +} + +func (s *Socket) beforeSave() {} + +func (s *Socket) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.socketOpsCommon) +} + +func (s *Socket) afterLoad() {} + +func (s *Socket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.socketOpsCommon) +} + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/netlink.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "ports", + "protocol", + "skType", + "ep", + "connection", + "bound", + "portID", + "sendBufferSize", + "filter", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.ports) + stateSinkObject.Save(2, &s.protocol) + stateSinkObject.Save(3, &s.skType) + stateSinkObject.Save(4, &s.ep) + stateSinkObject.Save(5, &s.connection) + stateSinkObject.Save(6, &s.bound) + stateSinkObject.Save(7, &s.portID) + stateSinkObject.Save(8, &s.sendBufferSize) + stateSinkObject.Save(9, &s.filter) +} + +func (s *socketOpsCommon) afterLoad() {} + +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.ports) + stateSourceObject.Load(2, &s.protocol) + stateSourceObject.Load(3, &s.skType) + stateSourceObject.Load(4, &s.ep) + stateSourceObject.Load(5, &s.connection) + stateSourceObject.Load(6, &s.bound) + stateSourceObject.Load(7, &s.portID) + stateSourceObject.Load(8, &s.sendBufferSize) + stateSourceObject.Load(9, &s.filter) +} + +func (k *kernelSCM) StateTypeName() string { + return "pkg/sentry/socket/netlink.kernelSCM" +} + +func (k *kernelSCM) StateFields() []string { + return []string{} +} + +func (k *kernelSCM) beforeSave() {} + +func (k *kernelSCM) StateSave(stateSinkObject state.Sink) { + k.beforeSave() +} + +func (k *kernelSCM) afterLoad() {} + +func (k *kernelSCM) StateLoad(stateSourceObject state.Source) { +} + +func (s *SocketVFS2) StateTypeName() string { + return "pkg/sentry/socket/netlink.SocketVFS2" +} + +func (s *SocketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "LockFD", + "socketOpsCommon", + } +} + +func (s *SocketVFS2) beforeSave() {} + +func (s *SocketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &s.LockFD) + stateSinkObject.Save(4, &s.socketOpsCommon) +} + +func (s *SocketVFS2) afterLoad() {} + +func (s *SocketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &s.LockFD) + stateSourceObject.Load(4, &s.socketOpsCommon) +} + +func init() { + state.Register((*Socket)(nil)) + state.Register((*socketOpsCommon)(nil)) + state.Register((*kernelSCM)(nil)) + state.Register((*SocketVFS2)(nil)) +} diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD deleted file mode 100644 index 3a22923d8..000000000 --- a/pkg/sentry/socket/netlink/port/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "port", - srcs = ["port.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/sync"], -) - -go_test( - name = "port_test", - srcs = ["port_test.go"], - library = ":port", -) diff --git a/pkg/sentry/socket/netlink/port/port_state_autogen.go b/pkg/sentry/socket/netlink/port/port_state_autogen.go new file mode 100644 index 000000000..e0083fcad --- /dev/null +++ b/pkg/sentry/socket/netlink/port/port_state_autogen.go @@ -0,0 +1,34 @@ +// automatically generated by stateify. + +package port + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (m *Manager) StateTypeName() string { + return "pkg/sentry/socket/netlink/port.Manager" +} + +func (m *Manager) StateFields() []string { + return []string{ + "ports", + } +} + +func (m *Manager) beforeSave() {} + +func (m *Manager) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.ports) +} + +func (m *Manager) afterLoad() {} + +func (m *Manager) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.ports) +} + +func init() { + state.Register((*Manager)(nil)) +} diff --git a/pkg/sentry/socket/netlink/port/port_test.go b/pkg/sentry/socket/netlink/port/port_test.go deleted file mode 100644 index 516f6cd6c..000000000 --- a/pkg/sentry/socket/netlink/port/port_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package port - -import ( - "testing" -) - -func TestAllocateHint(t *testing.T) { - m := New() - - // We can get the hint port. - p, ok := m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } - - // Hint is taken. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p == 1 { - t.Errorf("m.Allocate(0, 1) got 1 want anything else") - } - - // Hint is available for a different protocol. - p, ok = m.Allocate(1, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(1, 1) got %d want 1", p) - } - - m.Release(0, 1) - - // Hint is available again after release. - p, ok = m.Allocate(0, 1) - if !ok { - t.Errorf("m.Allocate got !ok want ok") - } - if p != 1 { - t.Errorf("m.Allocate(0, 1) got %d want 1", p) - } -} - -func TestAllocateExhausted(t *testing.T) { - m := New() - - // Fill all ports (0 is already reserved). - for i := int32(1); i < maxPorts; i++ { - p, ok := m.Allocate(0, i) - if !ok { - t.Fatalf("m.Allocate got !ok want ok") - } - if p != i { - t.Fatalf("m.Allocate(0, %d) got %d want %d", i, p, i) - } - } - - // Now no more can be allocated. - p, ok := m.Allocate(0, 1) - if ok { - t.Errorf("m.Allocate got %d, ok want !ok", p) - } -} diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD deleted file mode 100644 index 93127398d..000000000 --- a/pkg/sentry/socket/netlink/route/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "route", - srcs = [ - "protocol.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/socket/netlink", - "//pkg/syserr", - ], -) diff --git a/pkg/sentry/socket/netlink/route/route_state_autogen.go b/pkg/sentry/socket/netlink/route/route_state_autogen.go new file mode 100644 index 000000000..0b263b5b1 --- /dev/null +++ b/pkg/sentry/socket/netlink/route/route_state_autogen.go @@ -0,0 +1,30 @@ +// automatically generated by stateify. + +package route + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *Protocol) StateTypeName() string { + return "pkg/sentry/socket/netlink/route.Protocol" +} + +func (p *Protocol) StateFields() []string { + return []string{} +} + +func (p *Protocol) beforeSave() {} + +func (p *Protocol) StateSave(stateSinkObject state.Sink) { + p.beforeSave() +} + +func (p *Protocol) afterLoad() {} + +func (p *Protocol) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*Protocol)(nil)) +} diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD deleted file mode 100644 index b6434923c..000000000 --- a/pkg/sentry/socket/netlink/uevent/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "uevent", - srcs = ["protocol.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/sentry/kernel", - "//pkg/sentry/socket/netlink", - "//pkg/syserr", - ], -) diff --git a/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go b/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go new file mode 100644 index 000000000..63f488594 --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/uevent_state_autogen.go @@ -0,0 +1,30 @@ +// automatically generated by stateify. + +package uevent + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (p *Protocol) StateTypeName() string { + return "pkg/sentry/socket/netlink/uevent.Protocol" +} + +func (p *Protocol) StateFields() []string { + return []string{} +} + +func (p *Protocol) beforeSave() {} + +func (p *Protocol) StateSave(stateSinkObject state.Sink) { + p.beforeSave() +} + +func (p *Protocol) afterLoad() {} + +func (p *Protocol) StateLoad(stateSourceObject state.Source) { +} + +func init() { + state.Register((*Protocol)(nil)) +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD deleted file mode 100644 index fae3b6783..000000000 --- a/pkg/sentry/socket/netstack/BUILD +++ /dev/null @@ -1,58 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "netstack", - srcs = [ - "device.go", - "netstack.go", - "netstack_vfs2.go", - "provider.go", - "provider_vfs2.go", - "save_restore.go", - "stack.go", - ], - visibility = [ - "//pkg/sentry:internal", - ], - deps = [ - "//pkg/abi/linux", - "//pkg/amutex", - "//pkg/binary", - "//pkg/context", - "//pkg/log", - "//pkg/marshal", - "//pkg/marshal/primitive", - "//pkg/metric", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/netfilter", - "//pkg/sentry/unimpl", - "//pkg/sentry/vfs", - "//pkg/sync", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/usermem", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go new file mode 100644 index 000000000..8465d8743 --- /dev/null +++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go @@ -0,0 +1,155 @@ +// automatically generated by stateify. + +package netstack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (s *SocketOperations) StateTypeName() string { + return "pkg/sentry/socket/netstack.SocketOperations" +} + +func (s *SocketOperations) StateFields() []string { + return []string{ + "socketOpsCommon", + } +} + +func (s *SocketOperations) beforeSave() {} + +func (s *SocketOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.socketOpsCommon) +} + +func (s *SocketOperations) afterLoad() {} + +func (s *SocketOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.socketOpsCommon) +} + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/netstack.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "Queue", + "family", + "Endpoint", + "skType", + "protocol", + "readViewHasData", + "readView", + "readCM", + "sender", + "linkPacketInfo", + "sockOptTimestamp", + "timestampValid", + "timestampNS", + "sockOptInq", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.Queue) + stateSinkObject.Save(2, &s.family) + stateSinkObject.Save(3, &s.Endpoint) + stateSinkObject.Save(4, &s.skType) + stateSinkObject.Save(5, &s.protocol) + stateSinkObject.Save(6, &s.readViewHasData) + stateSinkObject.Save(7, &s.readView) + stateSinkObject.Save(8, &s.readCM) + stateSinkObject.Save(9, &s.sender) + stateSinkObject.Save(10, &s.linkPacketInfo) + stateSinkObject.Save(11, &s.sockOptTimestamp) + stateSinkObject.Save(12, &s.timestampValid) + stateSinkObject.Save(13, &s.timestampNS) + stateSinkObject.Save(14, &s.sockOptInq) +} + +func (s *socketOpsCommon) afterLoad() {} + +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.Queue) + stateSourceObject.Load(2, &s.family) + stateSourceObject.Load(3, &s.Endpoint) + stateSourceObject.Load(4, &s.skType) + stateSourceObject.Load(5, &s.protocol) + stateSourceObject.Load(6, &s.readViewHasData) + stateSourceObject.Load(7, &s.readView) + stateSourceObject.Load(8, &s.readCM) + stateSourceObject.Load(9, &s.sender) + stateSourceObject.Load(10, &s.linkPacketInfo) + stateSourceObject.Load(11, &s.sockOptTimestamp) + stateSourceObject.Load(12, &s.timestampValid) + stateSourceObject.Load(13, &s.timestampNS) + stateSourceObject.Load(14, &s.sockOptInq) +} + +func (s *SocketVFS2) StateTypeName() string { + return "pkg/sentry/socket/netstack.SocketVFS2" +} + +func (s *SocketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "LockFD", + "socketOpsCommon", + } +} + +func (s *SocketVFS2) beforeSave() {} + +func (s *SocketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &s.LockFD) + stateSinkObject.Save(4, &s.socketOpsCommon) +} + +func (s *SocketVFS2) afterLoad() {} + +func (s *SocketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &s.LockFD) + stateSourceObject.Load(4, &s.socketOpsCommon) +} + +func (s *Stack) StateTypeName() string { + return "pkg/sentry/socket/netstack.Stack" +} + +func (s *Stack) StateFields() []string { + return []string{} +} + +func (s *Stack) beforeSave() {} + +func (s *Stack) StateSave(stateSinkObject state.Sink) { + s.beforeSave() +} + +func (s *Stack) StateLoad(stateSourceObject state.Source) { + stateSourceObject.AfterLoad(s.afterLoad) +} + +func init() { + state.Register((*SocketOperations)(nil)) + state.Register((*socketOpsCommon)(nil)) + state.Register((*SocketVFS2)(nil)) + state.Register((*Stack)(nil)) +} diff --git a/pkg/sentry/socket/socket_state_autogen.go b/pkg/sentry/socket/socket_state_autogen.go new file mode 100644 index 000000000..4f854f99f --- /dev/null +++ b/pkg/sentry/socket/socket_state_autogen.go @@ -0,0 +1,91 @@ +// automatically generated by stateify. + +package socket + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (i *IPControlMessages) StateTypeName() string { + return "pkg/sentry/socket.IPControlMessages" +} + +func (i *IPControlMessages) StateFields() []string { + return []string{ + "HasTimestamp", + "Timestamp", + "HasInq", + "Inq", + "HasTOS", + "TOS", + "HasTClass", + "TClass", + "HasIPPacketInfo", + "PacketInfo", + "OriginalDstAddress", + } +} + +func (i *IPControlMessages) beforeSave() {} + +func (i *IPControlMessages) StateSave(stateSinkObject state.Sink) { + i.beforeSave() + stateSinkObject.Save(0, &i.HasTimestamp) + stateSinkObject.Save(1, &i.Timestamp) + stateSinkObject.Save(2, &i.HasInq) + stateSinkObject.Save(3, &i.Inq) + stateSinkObject.Save(4, &i.HasTOS) + stateSinkObject.Save(5, &i.TOS) + stateSinkObject.Save(6, &i.HasTClass) + stateSinkObject.Save(7, &i.TClass) + stateSinkObject.Save(8, &i.HasIPPacketInfo) + stateSinkObject.Save(9, &i.PacketInfo) + stateSinkObject.Save(10, &i.OriginalDstAddress) +} + +func (i *IPControlMessages) afterLoad() {} + +func (i *IPControlMessages) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &i.HasTimestamp) + stateSourceObject.Load(1, &i.Timestamp) + stateSourceObject.Load(2, &i.HasInq) + stateSourceObject.Load(3, &i.Inq) + stateSourceObject.Load(4, &i.HasTOS) + stateSourceObject.Load(5, &i.TOS) + stateSourceObject.Load(6, &i.HasTClass) + stateSourceObject.Load(7, &i.TClass) + stateSourceObject.Load(8, &i.HasIPPacketInfo) + stateSourceObject.Load(9, &i.PacketInfo) + stateSourceObject.Load(10, &i.OriginalDstAddress) +} + +func (to *SendReceiveTimeout) StateTypeName() string { + return "pkg/sentry/socket.SendReceiveTimeout" +} + +func (to *SendReceiveTimeout) StateFields() []string { + return []string{ + "send", + "recv", + } +} + +func (to *SendReceiveTimeout) beforeSave() {} + +func (to *SendReceiveTimeout) StateSave(stateSinkObject state.Sink) { + to.beforeSave() + stateSinkObject.Save(0, &to.send) + stateSinkObject.Save(1, &to.recv) +} + +func (to *SendReceiveTimeout) afterLoad() {} + +func (to *SendReceiveTimeout) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &to.send) + stateSourceObject.Load(1, &to.recv) +} + +func init() { + state.Register((*IPControlMessages)(nil)) + state.Register((*SendReceiveTimeout)(nil)) +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD deleted file mode 100644 index cce0acc33..000000000 --- a/pkg/sentry/socket/unix/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "socket_refs", - out = "socket_refs.go", - package = "unix", - prefix = "socketOperations", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "SocketOperations", - }, -) - -go_template_instance( - name = "socket_vfs2_refs", - out = "socket_vfs2_refs.go", - package = "unix", - prefix = "socketVFS2", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "SocketVFS2", - }, -) - -go_library( - name = "unix", - srcs = [ - "device.go", - "io.go", - "socket_refs.go", - "socket_vfs2_refs.go", - "unix.go", - "unix_vfs2.go", - ], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/fspath", - "//pkg/log", - "//pkg/marshal", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/safemem", - "//pkg/sentry/arch", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", - "//pkg/sentry/fsimpl/sockfs", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/control", - "//pkg/sentry/socket/netstack", - "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/vfs", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/usermem", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/unix/socket_refs.go b/pkg/sentry/socket/unix/socket_refs.go new file mode 100644 index 000000000..2a7fcb253 --- /dev/null +++ b/pkg/sentry/socket/unix/socket_refs.go @@ -0,0 +1,132 @@ +package unix + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const socketOperationsenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var socketOperationsobj *SocketOperations + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// +stateify savable +type socketOperationsRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *socketOperationsRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *socketOperationsRefs) RefType() string { + return fmt.Sprintf("%T", socketOperationsobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *socketOperationsRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *socketOperationsRefs) LogRefs() bool { + return socketOperationsenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *socketOperationsRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *socketOperationsRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if socketOperationsenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.RefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *socketOperationsRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if socketOperationsenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *socketOperationsRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if socketOperationsenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *socketOperationsRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/socket/unix/socket_vfs2_refs.go b/pkg/sentry/socket/unix/socket_vfs2_refs.go new file mode 100644 index 000000000..f10033260 --- /dev/null +++ b/pkg/sentry/socket/unix/socket_vfs2_refs.go @@ -0,0 +1,132 @@ +package unix + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const socketVFS2enableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var socketVFS2obj *SocketVFS2 + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// +stateify savable +type socketVFS2Refs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *socketVFS2Refs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *socketVFS2Refs) RefType() string { + return fmt.Sprintf("%T", socketVFS2obj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *socketVFS2Refs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *socketVFS2Refs) LogRefs() bool { + return socketVFS2enableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *socketVFS2Refs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *socketVFS2Refs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if socketVFS2enableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.RefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *socketVFS2Refs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if socketVFS2enableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *socketVFS2Refs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if socketVFS2enableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *socketVFS2Refs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD deleted file mode 100644 index 3ebbd28b0..000000000 --- a/pkg/sentry/socket/unix/transport/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "transport_message_list", - out = "transport_message_list.go", - package = "transport", - prefix = "message", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*message", - "Linker": "*message", - }, -) - -go_template_instance( - name = "queue_refs", - out = "queue_refs.go", - package = "transport", - prefix = "queue", - template = "//pkg/refsvfs2:refs_template", - types = { - "T": "queue", - }, -) - -go_library( - name = "transport", - srcs = [ - "connectioned.go", - "connectioned_state.go", - "connectionless.go", - "queue.go", - "queue_refs.go", - "transport_message_list.go", - "unix.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/ilist", - "//pkg/log", - "//pkg/refs", - "//pkg/refsvfs2", - "//pkg/sync", - "//pkg/syserr", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/waiter", - ], -) diff --git a/pkg/sentry/socket/unix/transport/queue_refs.go b/pkg/sentry/socket/unix/transport/queue_refs.go new file mode 100644 index 000000000..42c5b7ce0 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/queue_refs.go @@ -0,0 +1,132 @@ +package transport + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const queueenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var queueobj *queue + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// +stateify savable +type queueRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *queueRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *queueRefs) RefType() string { + return fmt.Sprintf("%T", queueobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *queueRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *queueRefs) LogRefs() bool { + return queueenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *queueRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *queueRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if queueenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.RefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *queueRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if queueenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *queueRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if queueenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *queueRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/sentry/socket/unix/transport/transport_message_list.go b/pkg/sentry/socket/unix/transport/transport_message_list.go new file mode 100644 index 000000000..dda579c27 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_message_list.go @@ -0,0 +1,193 @@ +package transport + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type messageElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (messageElementMapper) linkerFor(elem *message) *message { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type messageList struct { + head *message + tail *message +} + +// Reset resets list l to the empty state. +func (l *messageList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *messageList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *messageList) Front() *message { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *messageList) Back() *message { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *messageList) Len() (count int) { + for e := l.Front(); e != nil; e = (messageElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +func (l *messageList) PushFront(e *message) { + linker := messageElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + messageElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *messageList) PushBack(e *message) { + linker := messageElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *messageList) PushBackList(m *messageList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + messageElementMapper{}.linkerFor(l.tail).SetNext(m.head) + messageElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *messageList) InsertAfter(b, e *message) { + bLinker := messageElementMapper{}.linkerFor(b) + eLinker := messageElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + messageElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *messageList) InsertBefore(a, e *message) { + aLinker := messageElementMapper{}.linkerFor(a) + eLinker := messageElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + messageElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *messageList) Remove(e *message) { + linker := messageElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + messageElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + messageElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type messageEntry struct { + next *message + prev *message +} + +// Next returns the entry that follows e in the list. +func (e *messageEntry) Next() *message { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *messageEntry) Prev() *message { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *messageEntry) SetNext(elem *message) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *messageEntry) SetPrev(elem *message) { + e.prev = elem +} diff --git a/pkg/sentry/socket/unix/transport/transport_state_autogen.go b/pkg/sentry/socket/unix/transport/transport_state_autogen.go new file mode 100644 index 000000000..2aeff0256 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/transport_state_autogen.go @@ -0,0 +1,376 @@ +// automatically generated by stateify. + +package transport + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (e *connectionedEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.connectionedEndpoint" +} + +func (e *connectionedEndpoint) StateFields() []string { + return []string{ + "baseEndpoint", + "id", + "idGenerator", + "stype", + "acceptedChan", + } +} + +func (e *connectionedEndpoint) beforeSave() {} + +func (e *connectionedEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var acceptedChanValue []*connectionedEndpoint = e.saveAcceptedChan() + stateSinkObject.SaveValue(4, acceptedChanValue) + stateSinkObject.Save(0, &e.baseEndpoint) + stateSinkObject.Save(1, &e.id) + stateSinkObject.Save(2, &e.idGenerator) + stateSinkObject.Save(3, &e.stype) +} + +func (e *connectionedEndpoint) afterLoad() {} + +func (e *connectionedEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.baseEndpoint) + stateSourceObject.Load(1, &e.id) + stateSourceObject.Load(2, &e.idGenerator) + stateSourceObject.Load(3, &e.stype) + stateSourceObject.LoadValue(4, new([]*connectionedEndpoint), func(y interface{}) { e.loadAcceptedChan(y.([]*connectionedEndpoint)) }) +} + +func (e *connectionlessEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.connectionlessEndpoint" +} + +func (e *connectionlessEndpoint) StateFields() []string { + return []string{ + "baseEndpoint", + } +} + +func (e *connectionlessEndpoint) beforeSave() {} + +func (e *connectionlessEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.baseEndpoint) +} + +func (e *connectionlessEndpoint) afterLoad() {} + +func (e *connectionlessEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.baseEndpoint) +} + +func (q *queue) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.queue" +} + +func (q *queue) StateFields() []string { + return []string{ + "queueRefs", + "ReaderQueue", + "WriterQueue", + "closed", + "unread", + "used", + "limit", + "dataList", + } +} + +func (q *queue) beforeSave() {} + +func (q *queue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.queueRefs) + stateSinkObject.Save(1, &q.ReaderQueue) + stateSinkObject.Save(2, &q.WriterQueue) + stateSinkObject.Save(3, &q.closed) + stateSinkObject.Save(4, &q.unread) + stateSinkObject.Save(5, &q.used) + stateSinkObject.Save(6, &q.limit) + stateSinkObject.Save(7, &q.dataList) +} + +func (q *queue) afterLoad() {} + +func (q *queue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.queueRefs) + stateSourceObject.Load(1, &q.ReaderQueue) + stateSourceObject.Load(2, &q.WriterQueue) + stateSourceObject.Load(3, &q.closed) + stateSourceObject.Load(4, &q.unread) + stateSourceObject.Load(5, &q.used) + stateSourceObject.Load(6, &q.limit) + stateSourceObject.Load(7, &q.dataList) +} + +func (r *queueRefs) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.queueRefs" +} + +func (r *queueRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *queueRefs) beforeSave() {} + +func (r *queueRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +func (r *queueRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (l *messageList) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.messageList" +} + +func (l *messageList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *messageList) beforeSave() {} + +func (l *messageList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *messageList) afterLoad() {} + +func (l *messageList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *messageEntry) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.messageEntry" +} + +func (e *messageEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *messageEntry) beforeSave() {} + +func (e *messageEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *messageEntry) afterLoad() {} + +func (e *messageEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (c *ControlMessages) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.ControlMessages" +} + +func (c *ControlMessages) StateFields() []string { + return []string{ + "Rights", + "Credentials", + } +} + +func (c *ControlMessages) beforeSave() {} + +func (c *ControlMessages) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.Rights) + stateSinkObject.Save(1, &c.Credentials) +} + +func (c *ControlMessages) afterLoad() {} + +func (c *ControlMessages) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.Rights) + stateSourceObject.Load(1, &c.Credentials) +} + +func (m *message) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.message" +} + +func (m *message) StateFields() []string { + return []string{ + "messageEntry", + "Data", + "Control", + "Address", + } +} + +func (m *message) beforeSave() {} + +func (m *message) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.messageEntry) + stateSinkObject.Save(1, &m.Data) + stateSinkObject.Save(2, &m.Control) + stateSinkObject.Save(3, &m.Address) +} + +func (m *message) afterLoad() {} + +func (m *message) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.messageEntry) + stateSourceObject.Load(1, &m.Data) + stateSourceObject.Load(2, &m.Control) + stateSourceObject.Load(3, &m.Address) +} + +func (q *queueReceiver) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.queueReceiver" +} + +func (q *queueReceiver) StateFields() []string { + return []string{ + "readQueue", + } +} + +func (q *queueReceiver) beforeSave() {} + +func (q *queueReceiver) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.readQueue) +} + +func (q *queueReceiver) afterLoad() {} + +func (q *queueReceiver) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.readQueue) +} + +func (q *streamQueueReceiver) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.streamQueueReceiver" +} + +func (q *streamQueueReceiver) StateFields() []string { + return []string{ + "queueReceiver", + "buffer", + "control", + "addr", + } +} + +func (q *streamQueueReceiver) beforeSave() {} + +func (q *streamQueueReceiver) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.queueReceiver) + stateSinkObject.Save(1, &q.buffer) + stateSinkObject.Save(2, &q.control) + stateSinkObject.Save(3, &q.addr) +} + +func (q *streamQueueReceiver) afterLoad() {} + +func (q *streamQueueReceiver) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &q.queueReceiver) + stateSourceObject.Load(1, &q.buffer) + stateSourceObject.Load(2, &q.control) + stateSourceObject.Load(3, &q.addr) +} + +func (e *connectedEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.connectedEndpoint" +} + +func (e *connectedEndpoint) StateFields() []string { + return []string{ + "endpoint", + "writeQueue", + } +} + +func (e *connectedEndpoint) beforeSave() {} + +func (e *connectedEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.endpoint) + stateSinkObject.Save(1, &e.writeQueue) +} + +func (e *connectedEndpoint) afterLoad() {} + +func (e *connectedEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.endpoint) + stateSourceObject.Load(1, &e.writeQueue) +} + +func (e *baseEndpoint) StateTypeName() string { + return "pkg/sentry/socket/unix/transport.baseEndpoint" +} + +func (e *baseEndpoint) StateFields() []string { + return []string{ + "Queue", + "DefaultSocketOptionsHandler", + "receiver", + "connected", + "path", + "ops", + } +} + +func (e *baseEndpoint) beforeSave() {} + +func (e *baseEndpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.Queue) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.receiver) + stateSinkObject.Save(3, &e.connected) + stateSinkObject.Save(4, &e.path) + stateSinkObject.Save(5, &e.ops) +} + +func (e *baseEndpoint) afterLoad() {} + +func (e *baseEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.Queue) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.receiver) + stateSourceObject.Load(3, &e.connected) + stateSourceObject.Load(4, &e.path) + stateSourceObject.Load(5, &e.ops) +} + +func init() { + state.Register((*connectionedEndpoint)(nil)) + state.Register((*connectionlessEndpoint)(nil)) + state.Register((*queue)(nil)) + state.Register((*queueRefs)(nil)) + state.Register((*messageList)(nil)) + state.Register((*messageEntry)(nil)) + state.Register((*ControlMessages)(nil)) + state.Register((*message)(nil)) + state.Register((*queueReceiver)(nil)) + state.Register((*streamQueueReceiver)(nil)) + state.Register((*connectedEndpoint)(nil)) + state.Register((*baseEndpoint)(nil)) +} diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go new file mode 100644 index 000000000..fba990d9a --- /dev/null +++ b/pkg/sentry/socket/unix/unix_state_autogen.go @@ -0,0 +1,158 @@ +// automatically generated by stateify. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (r *socketOperationsRefs) StateTypeName() string { + return "pkg/sentry/socket/unix.socketOperationsRefs" +} + +func (r *socketOperationsRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *socketOperationsRefs) beforeSave() {} + +func (r *socketOperationsRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +func (r *socketOperationsRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (r *socketVFS2Refs) StateTypeName() string { + return "pkg/sentry/socket/unix.socketVFS2Refs" +} + +func (r *socketVFS2Refs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *socketVFS2Refs) beforeSave() {} + +func (r *socketVFS2Refs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +func (r *socketVFS2Refs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + +func (s *SocketOperations) StateTypeName() string { + return "pkg/sentry/socket/unix.SocketOperations" +} + +func (s *SocketOperations) StateFields() []string { + return []string{ + "socketOperationsRefs", + "socketOpsCommon", + } +} + +func (s *SocketOperations) beforeSave() {} + +func (s *SocketOperations) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.socketOperationsRefs) + stateSinkObject.Save(1, &s.socketOpsCommon) +} + +func (s *SocketOperations) afterLoad() {} + +func (s *SocketOperations) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.socketOperationsRefs) + stateSourceObject.Load(1, &s.socketOpsCommon) +} + +func (s *socketOpsCommon) StateTypeName() string { + return "pkg/sentry/socket/unix.socketOpsCommon" +} + +func (s *socketOpsCommon) StateFields() []string { + return []string{ + "SendReceiveTimeout", + "ep", + "stype", + "abstractName", + "abstractNamespace", + } +} + +func (s *socketOpsCommon) beforeSave() {} + +func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.SendReceiveTimeout) + stateSinkObject.Save(1, &s.ep) + stateSinkObject.Save(2, &s.stype) + stateSinkObject.Save(3, &s.abstractName) + stateSinkObject.Save(4, &s.abstractNamespace) +} + +func (s *socketOpsCommon) afterLoad() {} + +func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.SendReceiveTimeout) + stateSourceObject.Load(1, &s.ep) + stateSourceObject.Load(2, &s.stype) + stateSourceObject.Load(3, &s.abstractName) + stateSourceObject.Load(4, &s.abstractNamespace) +} + +func (s *SocketVFS2) StateTypeName() string { + return "pkg/sentry/socket/unix.SocketVFS2" +} + +func (s *SocketVFS2) StateFields() []string { + return []string{ + "vfsfd", + "FileDescriptionDefaultImpl", + "DentryMetadataFileDescriptionImpl", + "LockFD", + "socketVFS2Refs", + "socketOpsCommon", + } +} + +func (s *SocketVFS2) beforeSave() {} + +func (s *SocketVFS2) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.vfsfd) + stateSinkObject.Save(1, &s.FileDescriptionDefaultImpl) + stateSinkObject.Save(2, &s.DentryMetadataFileDescriptionImpl) + stateSinkObject.Save(3, &s.LockFD) + stateSinkObject.Save(4, &s.socketVFS2Refs) + stateSinkObject.Save(5, &s.socketOpsCommon) +} + +func (s *SocketVFS2) afterLoad() {} + +func (s *SocketVFS2) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.vfsfd) + stateSourceObject.Load(1, &s.FileDescriptionDefaultImpl) + stateSourceObject.Load(2, &s.DentryMetadataFileDescriptionImpl) + stateSourceObject.Load(3, &s.LockFD) + stateSourceObject.Load(4, &s.socketVFS2Refs) + stateSourceObject.Load(5, &s.socketOpsCommon) +} + +func init() { + state.Register((*socketOperationsRefs)(nil)) + state.Register((*socketVFS2Refs)(nil)) + state.Register((*SocketOperations)(nil)) + state.Register((*socketOpsCommon)(nil)) + state.Register((*SocketVFS2)(nil)) +} |