diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/control/control.go | 63 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 18 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 27 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 3 |
5 files changed, 89 insertions, 27 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 782a3cb92..af1a4e95f 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -195,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ // the available space, we must align down. // // align must be >= 4 and each data int32 is 4 bytes. The length of the - // header is already aligned, so if we align to the with of the data there + // header is already aligned, so if we align to the width of the data there // are two cases: // 1. The aligned length is less than the length of the header. The // unaligned length was also less than the length of the header, so we // can't write anything. // 2. The aligned length is greater than or equal to the length of the - // header. We can write the header plus zero or more datas. We can't write - // a partial int32, so the length of the message will be - // min(aligned length, header + datas). + // header. We can write the header plus zero or more bytes of data. We can't + // write a partial int32, so the length of the message will be + // min(aligned length, header + data). if space < linux.SizeOfControlMessageHeader { flags |= linux.MSG_CTRUNC return buf, flags @@ -240,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = binary.Marshal(buf, usermem.ByteOrder, data) - // Check if we went over. + // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { return hdrBuf } - // Fix up length. + // Update control message length to include data. putUint64(ob, uint64(len(buf)-len(ob))) return alignSlice(buf, align) @@ -348,43 +348,62 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { ) } -func addSpaceForCmsg(cmsgDataLen int, buf []byte) []byte { - newBuf := make([]byte, 0, len(buf)+linux.SizeOfControlMessageHeader+cmsgDataLen) - return append(newBuf, buf...) -} - -// PackControlMessages converts the given ControlMessages struct into a buffer. +// PackControlMessages packs control messages into the given buffer. +// // We skip control messages specific to Unix domain sockets. -func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages) []byte { - var buf []byte - // The use of t.Arch().Width() is analogous to Linux's use of sizeof(long) in - // CMSG_ALIGN. - width := t.Arch().Width() - +// +// Note that some control messages may be truncated if they do not fit under +// the capacity of buf. +func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { if cmsgs.IP.HasTimestamp { - buf = addSpaceForCmsg(int(width), buf) buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) } if cmsgs.IP.HasInq { // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. - buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageInq, width), buf) buf = PackInq(t, cmsgs.IP.Inq, buf) } if cmsgs.IP.HasTOS { - buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTOS, width), buf) buf = PackTOS(t, cmsgs.IP.TOS, buf) } if cmsgs.IP.HasTClass { - buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTClass, width), buf) buf = PackTClass(t, cmsgs.IP.TClass, buf) } return buf } +// cmsgSpace is equivalent to CMSG_SPACE in Linux. +func cmsgSpace(t *kernel.Task, dataLen int) int { + return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width()) +} + +// CmsgsSpace returns the number of bytes needed to fit the control messages +// represented in cmsgs. +func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { + space := 0 + + if cmsgs.IP.HasTimestamp { + space += cmsgSpace(t, linux.SizeOfTimeval) + } + + if cmsgs.IP.HasInq { + space += cmsgSpace(t, linux.SizeOfControlMessageInq) + } + + if cmsgs.IP.HasTOS { + space += cmsgSpace(t, linux.SizeOfControlMessageTOS) + } + + if cmsgs.IP.HasTClass { + space += cmsgSpace(t, linux.SizeOfControlMessageTClass) + } + + return space +} + // Parse parses a raw socket control message into portable objects. func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a8c152b54..c957b0f1d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -45,7 +45,7 @@ const ( sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in) // maxControlLen is the maximum size of a control message buffer used in a - // recvmsg syscall. + // recvmsg or sendmsg syscall. maxControlLen = 1024 ) @@ -412,9 +412,12 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msg.Namelen = uint32(len(senderAddrBuf)) } if controlLen > 0 { - controlBuf = make([]byte, maxControlLen) + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) msg.Control = &controlBuf[0] - msg.Controllen = maxControlLen + msg.Controllen = controlLen } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { @@ -489,7 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.ErrInvalidArgument } - controlBuf := control.PackControlMessages(t, controlMessages) + space := uint64(control.CmsgsSpace(t, controlMessages)) + if space > maxControlLen { + space = maxControlLen + } + controlBuf := make([]byte, 0, space) + // PackControlMessages will append up to space bytes to controlBuf. + controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. if uint64(src.NumBytes()) != srcs.NumBytes() { diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d92399efd..140851c17 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -151,6 +151,8 @@ var Metrics = tcpip.Stats{ PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), @@ -324,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { + if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } @@ -1125,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Millisecond), nil + case linux.TCP_INFO: var v tcpip.TCPInfoOption if err := ep.GetSockOpt(&v); err != nil { @@ -1561,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 8c250c325..2389a9cdb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -43,6 +43,11 @@ type ControlMessages struct { IP tcpip.ControlMessages } +// Release releases Unix domain socket credentials and rights. +func (c *ControlMessages) Release() { + c.Unix.Release() +} + // Socket is the interface containing socket syscalls used by the syscall layer // to redirect them to the appropriate implementation. type Socket interface { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1aaae8487..885758054 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -118,6 +118,9 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { func extractPath(sockaddr []byte) (string, *syserr.Error) { addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } |