summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry')
-rw-r--r--pkg/sentry/socket/control/control.go19
-rw-r--r--pkg/sentry/socket/hostinet/socket.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go57
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go3
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go36
-rw-r--r--pkg/sentry/vfs/save_restore.go6
-rw-r--r--pkg/sentry/vfs/vfs_unsafe_state_autogen.go37
8 files changed, 89 insertions, 74 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..c284efde5 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -359,13 +359,26 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte)
)
}
+// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip.FullAddress, buf []byte) []byte {
+ p, _ := socket.ConvertAddress(family, originalDstAddress)
+ level := uint32(linux.SOL_IP)
+ optType := uint32(linux.IP_RECVORIGDSTADDR)
+ if family == linux.AF_INET6 {
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), p)
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
//
// Note that some control messages may be truncated if they do not fit under
// the capacity of buf.
-func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessages, buf []byte) []byte {
if cmsgs.IP.HasTimestamp {
buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
}
@@ -387,6 +400,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
}
+ if cmsgs.IP.HasOriginalDstAddress {
+ buf = PackOriginalDstAddress(t, family, cmsgs.IP.OriginalDstAddress, buf)
+ }
+
return buf
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..a35850a8f 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -551,7 +551,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
controlBuf := make([]byte, 0, space)
// PackControlMessages will append up to space bytes to controlBuf.
- controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+ controlBuf = control.PackControlMessages(t, s.family, controlMessages, controlBuf)
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 3cc0d4f0f..37c3faa57 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1418,6 +1418,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
return &v, nil
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
+
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
return nil, syserr.ErrInvalidArgument
@@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
return &v, nil
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
+
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
return nil, syserr.ErrInvalidArgument
@@ -2094,6 +2110,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2325,6 +2350,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
ep.SocketOptions().SetHeaderIncluded(v != 0)
return nil
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
return syserr.ErrInvalidArgument
@@ -2363,7 +2400,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2441,7 +2477,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2746,14 +2781,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ HasOriginalDstAddress: s.readCM.HasOriginalDstAddress,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
},
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 9cd052c3d..754a5df24 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -784,8 +784,9 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
defer cms.Release(t)
+ family, _, _ := s.Type()
controlData := make([]byte, 0, msg.ControlLen)
- controlData = control.PackControlMessages(t, cms, controlData)
+ controlData = control.PackControlMessages(t, family, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 7b33b3f59..c40892b9f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -787,8 +787,9 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
}
defer cms.Release(t)
+ family, _, _ := s.Type()
controlData := make([]byte, 0, msg.ControlLen)
- controlData = control.PackControlMessages(t, cms, controlData)
+ controlData = control.PackControlMessages(t, family, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index cb48c37a1..0df023713 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package vfs
import (
@@ -41,6 +36,15 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
+var (
+ mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
+ mountKeySeed = sync.RandUintptr()
+)
+
+func (k *mountKey) hash() uintptr {
+ return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed)
+}
+
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry {
}
}
-func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
-
// Invariant: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
-func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
-
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
-//
-// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -84,8 +82,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount `state:"nosave"`
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -150,7 +147,6 @@ func init() {
// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
- mt.seed = rand32()
mt.size = mtInitOrder
mt.slots = newMountTableSlots(mtInitCap)
}
@@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := key.hash()
loop:
for {
@@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
// We're under the maximum load factor if:
//
@@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must contain mount.
func (mt *mountTable) removeSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
slots := mt.slots
@@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) {
off = (off + mountSlotBytes) & offmask
}
}
-
-//go:linkname memhash runtime.memhash
-func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
-
-//go:linkname rand32 runtime.fastrand
-func rand32() uint32
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 8f070ed53..8998a82dd 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -101,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount {
return mounts
}
+// saveKey is called by stateify.
+func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
+
// loadMounts is called by stateify.
func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
if mounts == nil {
@@ -112,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
}
}
+// loadKey is called by stateify.
+func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
+
func (mnt *Mount) afterLoad() {
if atomic.LoadInt64(&mnt.refs) != 0 {
refsvfs2.Register(mnt)
diff --git a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go
index d23733ee7..20f06c953 100644
--- a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go
+++ b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go
@@ -1,40 +1,3 @@
// automatically generated by stateify.
-// +build go1.12
-// +build !go1.17
-
package vfs
-
-import (
- "gvisor.dev/gvisor/pkg/state"
-)
-
-func (mt *mountTable) StateTypeName() string {
- return "pkg/sentry/vfs.mountTable"
-}
-
-func (mt *mountTable) StateFields() []string {
- return []string{
- "seed",
- "size",
- }
-}
-
-func (mt *mountTable) beforeSave() {}
-
-func (mt *mountTable) StateSave(stateSinkObject state.Sink) {
- mt.beforeSave()
- stateSinkObject.Save(0, &mt.seed)
- stateSinkObject.Save(1, &mt.size)
-}
-
-func (mt *mountTable) afterLoad() {}
-
-func (mt *mountTable) StateLoad(stateSourceObject state.Source) {
- stateSourceObject.Load(0, &mt.seed)
- stateSourceObject.Load(1, &mt.size)
-}
-
-func init() {
- state.Register((*mountTable)(nil))
-}