diff options
Diffstat (limited to 'pkg/sentry')
-rw-r--r-- | pkg/sentry/socket/control/control.go | 19 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 57 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_socket.go | 3 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/vfs2/socket.go | 3 | ||||
-rw-r--r-- | pkg/sentry/vfs/mount_unsafe.go | 36 | ||||
-rw-r--r-- | pkg/sentry/vfs/save_restore.go | 6 | ||||
-rw-r--r-- | pkg/sentry/vfs/vfs_unsafe_state_autogen.go | 37 |
8 files changed, 89 insertions, 74 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..c284efde5 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -359,13 +359,26 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) ) } +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip.FullAddress, buf []byte) []byte { + p, _ := socket.ConvertAddress(family, originalDstAddress) + level := uint32(linux.SOL_IP) + optType := uint32(linux.IP_RECVORIGDSTADDR) + if family == linux.AF_INET6 { + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), p) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. // // Note that some control messages may be truncated if they do not fit under // the capacity of buf. -func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { +func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessages, buf []byte) []byte { if cmsgs.IP.HasTimestamp { buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) } @@ -387,6 +400,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) } + if cmsgs.IP.HasOriginalDstAddress { + buf = PackOriginalDstAddress(t, family, cmsgs.IP.OriginalDstAddress, buf) + } + return buf } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..a35850a8f 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -551,7 +551,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } controlBuf := make([]byte, 0, space) // PackControlMessages will append up to space bytes to controlBuf. - controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + controlBuf = control.PackControlMessages(t, s.family, controlMessages, controlBuf) sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3cc0d4f0f..37c3faa57 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1418,6 +1418,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) return &v, nil + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { return nil, syserr.ErrInvalidArgument @@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) return &v, nil + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { return nil, syserr.ErrInvalidArgument @@ -2094,6 +2110,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2325,6 +2350,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetHeaderIncluded(v != 0) return nil + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { return syserr.ErrInvalidArgument @@ -2363,7 +2400,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2441,7 +2477,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2746,14 +2781,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + HasOriginalDstAddress: s.readCM.HasOriginalDstAddress, + OriginalDstAddress: s.readCM.OriginalDstAddress, }, } } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9cd052c3d..754a5df24 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -784,8 +784,9 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } defer cms.Release(t) + family, _, _ := s.Type() controlData := make([]byte, 0, msg.ControlLen) - controlData = control.PackControlMessages(t, cms, controlData) + controlData = control.PackControlMessages(t, family, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 7b33b3f59..c40892b9f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -787,8 +787,9 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } defer cms.Release(t) + family, _, _ := s.Type() controlData := make([]byte, 0, msg.ControlLen) - controlData = control.PackControlMessages(t, cms, controlData) + controlData = control.PackControlMessages(t, family, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index cb48c37a1..0df023713 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - package vfs import ( @@ -41,6 +36,15 @@ type mountKey struct { point unsafe.Pointer // *Dentry } +var ( + mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil)) + mountKeySeed = sync.RandUintptr() +) + +func (k *mountKey) hash() uintptr { + return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed) +} + func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry { } } -func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } - // Invariant: mnt.key.parent == nil. vd.Ok(). func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } -func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } - // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). // // mountTable.Init() must be called on new mountTables before use. -// -// +stateify savable type mountTable struct { // mountTable is implemented as a seqcount-protected hash table that // resolves collisions with linear probing, featuring Robin Hood insertion @@ -84,8 +82,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq sync.SeqCount `state:"nosave"` - seed uint32 // for hashing keys + seq sync.SeqCount `state:"nosave"` // size holds both length (number of elements) and capacity (number of // slots): capacity is stored as its base-2 log (referred to as order) in @@ -150,7 +147,6 @@ func init() { // Init must be called exactly once on each mountTable before use. func (mt *mountTable) Init() { - mt.seed = rand32() mt.size = mtInitOrder mt.slots = newMountTableSlots(mtInitCap) } @@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := key.hash() loop: for { @@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() // We're under the maximum load factor if: // @@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 slots := mt.slots @@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) { off = (off + mountSlotBytes) & offmask } } - -//go:linkname memhash runtime.memhash -func memhash(p unsafe.Pointer, seed, s uintptr) uintptr - -//go:linkname rand32 runtime.fastrand -func rand32() uint32 diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go index 8f070ed53..8998a82dd 100644 --- a/pkg/sentry/vfs/save_restore.go +++ b/pkg/sentry/vfs/save_restore.go @@ -101,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount { return mounts } +// saveKey is called by stateify. +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + // loadMounts is called by stateify. func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { if mounts == nil { @@ -112,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { } } +// loadKey is called by stateify. +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + func (mnt *Mount) afterLoad() { if atomic.LoadInt64(&mnt.refs) != 0 { refsvfs2.Register(mnt) diff --git a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go index d23733ee7..20f06c953 100644 --- a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go +++ b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go @@ -1,40 +1,3 @@ // automatically generated by stateify. -// +build go1.12 -// +build !go1.17 - package vfs - -import ( - "gvisor.dev/gvisor/pkg/state" -) - -func (mt *mountTable) StateTypeName() string { - return "pkg/sentry/vfs.mountTable" -} - -func (mt *mountTable) StateFields() []string { - return []string{ - "seed", - "size", - } -} - -func (mt *mountTable) beforeSave() {} - -func (mt *mountTable) StateSave(stateSinkObject state.Sink) { - mt.beforeSave() - stateSinkObject.Save(0, &mt.seed) - stateSinkObject.Save(1, &mt.size) -} - -func (mt *mountTable) afterLoad() {} - -func (mt *mountTable) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(0, &mt.seed) - stateSourceObject.Load(1, &mt.size) -} - -func init() { - state.Register((*mountTable)(nil)) -} |