summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/strace/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/strace/socket.go')
-rw-r--r--pkg/sentry/strace/socket.go131
1 files changed, 65 insertions, 66 deletions
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index f4aab25b0..b164d9107 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -161,12 +161,10 @@ var controlMessageType = map[int32]string{
linux.SO_TIMESTAMP: "SO_TIMESTAMP",
}
-func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights {
+func unmarshalControlMessageRights(src []byte) []primitive.Int32 {
count := len(src) / linux.SizeOfControlMessageRight
- cmr := make(linux.ControlMessageRights, count)
- for i, _ := range cmr {
- cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:]))
- }
+ cmr := make([]primitive.Int32, count)
+ primitive.UnmarshalUnsafeInt32Slice(cmr, src)
return cmr
}
@@ -182,14 +180,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
var strs []string
- for i := 0; i < len(buf); {
- if i+linux.SizeOfControlMessageHeader > len(buf) {
+ for len(buf) > 0 {
+ if linux.SizeOfControlMessageHeader > len(buf) {
strs = append(strs, "{invalid control message (too short)}")
break
}
var h linux.ControlMessageHeader
- h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
+ buf = h.UnmarshalUnsafe(buf)
var skipData bool
level := "SOL_SOCKET"
@@ -204,7 +202,9 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
typ = fmt.Sprint(h.Type)
}
- if h.Length > uint64(len(buf)-i) {
+ width := t.Arch().Width()
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length > len(buf) {
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, content extends beyond buffer}",
level,
@@ -214,9 +214,6 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
break
}
- i += linux.SizeOfControlMessageHeader
- width := t.Arch().Width()
- length := int(h.Length) - linux.SizeOfControlMessageHeader
if length < 0 {
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, content too short}",
@@ -229,78 +226,80 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += bits.AlignUp(length, width)
- continue
- }
-
- switch h.Type {
- case linux.SCM_RIGHTS:
- rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
- fds := unmarshalControlMessageRights(buf[i : i+rightsSize])
- rights := make([]string, 0, len(fds))
- for _, fd := range fds {
- rights = append(rights, fmt.Sprint(fd))
- }
+ } else {
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
+ fds := unmarshalControlMessageRights(buf[:rightsSize])
+ rights := make([]string, 0, len(fds))
+ for _, fd := range fds {
+ rights = append(rights, fmt.Sprint(fd))
+ }
- strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, content: %s}",
- level,
- typ,
- h.Length,
- strings.Join(rights, ","),
- ))
-
- case linux.SCM_CREDENTIALS:
- if length < linux.SizeOfControlMessageCredentials {
strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, content too short}",
+ "{level=%s, type=%s, length=%d, content: %s}",
level,
typ,
h.Length,
+ strings.Join(rights, ","),
))
- break
- }
- var creds linux.ControlMessageCredentials
- creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
- strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
- level,
- typ,
- h.Length,
- creds.PID,
- creds.UID,
- creds.GID,
- ))
+ var creds linux.ControlMessageCredentials
+ creds.UnmarshalUnsafe(buf)
- case linux.SO_TIMESTAMP:
- if length < linux.SizeOfTimeval {
strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, content too short}",
+ "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
level,
typ,
h.Length,
+ creds.PID,
+ creds.UID,
+ creds.GID,
))
- break
- }
- var tv linux.Timeval
- tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
- strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
- level,
- typ,
- h.Length,
- tv.Sec,
- tv.Usec,
- ))
+ var tv linux.Timeval
+ tv.UnmarshalUnsafe(buf)
- default:
- panic("unreachable")
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
+ level,
+ typ,
+ h.Length,
+ tv.Sec,
+ tv.Usec,
+ ))
+
+ default:
+ panic("unreachable")
+ }
+ }
+ if shift := bits.AlignUp(length, width); shift > len(buf) {
+ buf = buf[:0]
+ } else {
+ buf = buf[shift:]
}
- i += bits.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))