diff options
Diffstat (limited to 'pkg/sentry/strace/socket.go')
-rw-r--r-- | pkg/sentry/strace/socket.go | 131 |
1 files changed, 65 insertions, 66 deletions
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f4aab25b0..b164d9107 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -161,12 +161,10 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } -func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { +func unmarshalControlMessageRights(src []byte) []primitive.Int32 { count := len(src) / linux.SizeOfControlMessageRight - cmr := make(linux.ControlMessageRights, count) - for i, _ := range cmr { - cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) - } + cmr := make([]primitive.Int32, count) + primitive.UnmarshalUnsafeInt32Slice(cmr, src) return cmr } @@ -182,14 +180,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) var strs []string - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { strs = append(strs, "{invalid control message (too short)}") break } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) var skipData bool level := "SOL_SOCKET" @@ -204,7 +202,9 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) typ = fmt.Sprint(h.Type) } - if h.Length > uint64(len(buf)-i) { + width := t.Arch().Width() + length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content extends beyond buffer}", level, @@ -214,9 +214,6 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) break } - i += linux.SizeOfControlMessageHeader - width := t.Arch().Width() - length := int(h.Length) - linux.SizeOfControlMessageHeader if length < 0 { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content too short}", @@ -229,78 +226,80 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += bits.AlignUp(length, width) - continue - } - - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) - fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) - rights := make([]string, 0, len(fds)) - for _, fd := range fds { - rights = append(rights, fmt.Sprint(fd)) - } + } else { + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[:rightsSize]) + rights := make([]string, 0, len(fds)) + for _, fd := range fds { + rights = append(rights, fmt.Sprint(fd)) + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content: %s}", - level, - typ, - h.Length, - strings.Join(rights, ","), - )) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, content: %s}", level, typ, h.Length, + strings.Join(rights, ","), )) - break - } - var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", - level, - typ, - h.Length, - creds.PID, - creds.UID, - creds.GID, - )) + var creds linux.ControlMessageCredentials + creds.UnmarshalUnsafe(buf) - case linux.SO_TIMESTAMP: - if length < linux.SizeOfTimeval { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", level, typ, h.Length, + creds.PID, + creds.UID, + creds.GID, )) - break - } - var tv linux.Timeval - tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", - level, - typ, - h.Length, - tv.Sec, - tv.Usec, - )) + var tv linux.Timeval + tv.UnmarshalUnsafe(buf) - default: - panic("unreachable") + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", + level, + typ, + h.Length, + tv.Sec, + tv.Usec, + )) + + default: + panic("unreachable") + } + } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] } - i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) |