summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/netstack
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2021-01-12 18:39:47 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-12 18:41:41 -0800
commit626a8ca22590d78dd4d3dd319cf7f98770924b94 (patch)
tree1a12c2347a10fa3ace2a95423720de730b8d2ebe /pkg/sentry/socket/netstack
parent8b0f0b4d11e0938eec8da411323b2ce35976ab56 (diff)
Remove useless cached state
Simplify some logic while I'm here. PiperOrigin-RevId: 351491593
Diffstat (limited to 'pkg/sentry/socket/netstack')
-rw-r--r--pkg/sentry/socket/netstack/netstack.go223
1 files changed, 77 insertions, 146 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index dcf898c0a..57f224120 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -309,11 +309,6 @@ type socketOpsCommon struct {
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
- // readCM holds control message information for the last packet read
- // from Endpoint.
- readCM socket.IPControlMessages
- sender tcpip.FullAddress
- linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When
@@ -368,25 +363,6 @@ func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
-// Precondition: s.readMu must be held.
-func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
- res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
- Peek: peek,
- NeedRemoteAddr: true,
- NeedLinkPacketInfo: true,
- })
-
- // Assign these anyways.
- s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
- s.sender = res.RemoteAddr
- s.linkPacketInfo = res.LinkPacketInfo
-
- if err != nil {
- return 0, 0, syserr.TranslateNetstackError(err)
- }
- return res.Count, res.Total, nil
-}
-
// Release implements fs.FileOperations.Release.
func (s *socketOpsCommon) Release(ctx context.Context) {
e, ch := waiter.NewChannelEntry(nil)
@@ -436,11 +412,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
defer s.readMu.Unlock()
// This may return a blocking error.
- n, _, err := s.readLocked(dst, int(count), dup /* peek */)
+ res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{
+ Peek: dup,
+ })
if err != nil {
- return 0, err.ToError()
+ return 0, syserr.TranslateNetstackError(err).ToError()
}
- return int64(n), nil
+ return int64(res.Count), nil
}
// ioSequencePayload implements tcpip.Payload.
@@ -2557,22 +2535,6 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return a, l, nil
}
-// streamRead is the fast path for non-blocking, non-peek, stream-based socket.
-//
-// Precondition: s.readMu must be locked.
-func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
- // Always do at least one read, even if the number of bytes to read is 0.
- var n int
- n, _, err := s.readLocked(dst, count, false /* peek */)
- if err != nil {
- return 0, err
- }
- if n > 0 {
- s.Endpoint.ModerateRecvBuf(n)
- }
- return n, nil
-}
-
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
if !s.sockOptInq {
return
@@ -2608,133 +2570,102 @@ func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()
- // Fast path for regular reads from stream (e.g., TCP) endpoints. Note
- // that senderRequested is ignored for stream sockets.
- if !peek && !isPacket {
- // TCP sockets discard the data if MSG_TRUNC is set.
- //
- // This behavior is documented in man 7 tcp:
- // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
- // argument of recv(2) (and recvmsg(2)). This flag causes the received
- // bytes of data to be discarded, rather than passed back in a
- // caller-supplied buffer.
- s.readMu.Lock()
-
- var w io.Writer
- if trunc {
- w = ioutil.Discard
- } else {
- w = dst.Writer(ctx)
- }
-
- n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
-
- if err == nil && !trunc {
- // Set the control message, even if 0 bytes were read.
- s.updateTimestamp()
- }
-
- cmsg := s.controlMessages()
- s.fillCmsgInq(&cmsg)
- s.readMu.Unlock()
- return n, 0, nil, 0, cmsg, err
+ readOptions := tcpip.ReadOptions{
+ Peek: peek,
+ NeedRemoteAddr: senderRequested,
+ NeedLinkPacketInfo: isPacket,
}
- s.readMu.Lock()
- defer s.readMu.Unlock()
-
- // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
- // amount that could be read, and does not write to buffer.
- isTCPPeekTrunc := !isPacket && peek && trunc
-
+ // TCP sockets discard the data if MSG_TRUNC is set.
+ //
+ // This behavior is documented in man 7 tcp:
+ // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
+ // argument of recv(2) (and recvmsg(2)). This flag causes the received
+ // bytes of data to be discarded, rather than passed back in a
+ // caller-supplied buffer.
var w io.Writer
- if isTCPPeekTrunc {
+ if !isPacket && trunc {
w = ioutil.Discard
} else {
w = dst.Writer(ctx)
}
- var numRead, numTotal int
- var err *syserr.Error
- numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
- if err != nil {
- return 0, 0, nil, 0, socket.ControlMessages{}, err
- }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
- if isTCPPeekTrunc {
- // TCP endpoint does not return the total bytes in buffer as numTotal.
- // We need to query it from socket option.
- rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
- if err != nil {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
- }
- available := int(rql)
- bufLen := int(dst.NumBytes())
- if available < bufLen {
- return available, 0, nil, 0, socket.ControlMessages{}, nil
- }
- return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
+ res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions)
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
-
// Set the control message, even if 0 bytes were read.
- s.updateTimestamp()
+ s.updateTimestamp(res.ControlMessages)
- var addr linux.SockAddr
- var addrLen uint32
- if isPacket && senderRequested {
- addr, addrLen = socket.ConvertAddress(s.family, s.sender)
- switch v := addr.(type) {
- case *linux.SockAddrLink:
- v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
- v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
+ if isPacket {
+ var addr linux.SockAddr
+ var addrLen uint32
+ if senderRequested {
+ addr, addrLen = socket.ConvertAddress(s.family, res.RemoteAddr)
+ switch v := addr.(type) {
+ case *linux.SockAddrLink:
+ v.Protocol = socket.Htons(uint16(res.LinkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(res.LinkPacketInfo.PktType)
+ }
}
- }
- if peek {
- if trunc && numTotal > numRead {
- // isPacket must be true.
- return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
+ msgLen := res.Count
+ if trunc {
+ msgLen = res.Total
}
- return numRead, 0, nil, 0, s.controlMessages(), nil
- }
- var msgLen int
- if isPacket {
- msgLen = numTotal
- } else {
- msgLen = numRead
- }
+ var flags int
+ if res.Total > res.Count {
+ flags |= linux.MSG_TRUNC
+ }
- var flags int
- if msgLen > numRead {
- flags |= linux.MSG_TRUNC
+ return msgLen, flags, addr, addrLen, s.controlMessages(res.ControlMessages), nil
}
- n := numRead
- if trunc {
- n = msgLen
+ if peek {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read, and does not write to buffer.
+ if trunc {
+ // TCP endpoint does not return the total bytes in buffer as numTotal.
+ // We need to query it from socket option.
+ rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
+ }
+ msgLen := int(dst.NumBytes())
+ if msgLen > rql {
+ msgLen = rql
+ }
+ return msgLen, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+ } else if n := res.Count; n != 0 {
+ s.Endpoint.ModerateRecvBuf(n)
}
- cmsg := s.controlMessages()
+ cmsg := s.controlMessages(res.ControlMessages)
s.fillCmsgInq(&cmsg)
- return n, flags, addr, addrLen, cmsg, nil
+ return res.Count, 0, nil, 0, cmsg, syserr.TranslateNetstackError(err)
}
-func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
+func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.ControlMessages {
+ readCM := socket.NewIPControlMessages(s.family, cm)
return socket.ControlMessages{
IP: socket.IPControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasInq: s.readCM.HasInq,
- Inq: s.readCM.Inq,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
- OriginalDstAddress: s.readCM.OriginalDstAddress,
- SockErr: s.readCM.SockErr,
+ HasTimestamp: readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: readCM.Timestamp,
+ HasInq: readCM.HasInq,
+ Inq: readCM.Inq,
+ HasTOS: readCM.HasTOS,
+ TOS: readCM.TOS,
+ HasTClass: readCM.HasTClass,
+ TClass: readCM.TClass,
+ HasIPPacketInfo: readCM.HasIPPacketInfo,
+ PacketInfo: readCM.PacketInfo,
+ OriginalDstAddress: readCM.OriginalDstAddress,
+ SockErr: readCM.SockErr,
},
}
}
@@ -2743,11 +2674,11 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
// successfully writing packet data out to userspace.
//
// Precondition: s.readMu must be locked.
-func (s *socketOpsCommon) updateTimestamp() {
+func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp {
s.timestampValid = true
- s.timestampNS = s.readCM.Timestamp
+ s.timestampNS = cm.Timestamp
}
}