summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go11
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go29
-rw-r--r--test/syscalls/linux/tcp_socket.cc15
4 files changed, 43 insertions, 14 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 46d3a6646..3e6196aee 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -451,7 +451,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
for _, r := range primaryAddrs {
// If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoing() {
+ if !r.isValidForOutgoingRLocked() {
continue
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index cd247f3e1..ae4f3f3a9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.state = handshakeSynRcvd
h.ep.mu.Lock()
ttl := h.ep.ttl
+ amss := h.ep.amss
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
@@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
if ttl == 0 {
ttl = s.route.DefaultTTL()
@@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ h.ep.mu.RLock()
+ amss := h.ep.amss
+ h.ep.mu.RUnlock()
+
h.resetState()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
+ h.ep.mu.Lock()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
@@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
+ h.ep.mu.Unlock()
// Execute is also called in a listen context so we want to make sure we
// only send the TS/SACK option when we received the TS/SACK in the
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9e72730bd..8b9154e69 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int {
// ModerateRecvBuf adjusts the receive buffer and the advertised window
// based on the number of bytes copied to user space.
func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.mu.RLock()
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
return
}
now := time.Now()
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
e.rcvAutoParams.copied += copied
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
return
}
prevRTTCopied := e.rcvAutoParams.copied + copied
@@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvBufSize = rcvWnd
availAfter := e.receiveBufferAvailableLocked()
mask := uint32(notifyReceiveWindowChanged)
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.notifyProtocolGoroutine(mask)
@@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvAutoParams.measureTime = now
e.rcvAutoParams.copied = 0
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
}
// IPTables implements tcpip.Endpoint.IPTables.
@@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
-
e.mu.RUnlock()
if err == tcpip.ErrClosedForReceive {
@@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// windowCrossedACKThreshold checks if the receive window to be announced now
-// would be under aMSS or under half receive buffer, whichever smaller. This is
-// useful as a receive side silly window syndrome prevention mechanism. If
+// windowCrossedACKThresholdLocked checks if the receive window to be announced
+// now would be under aMSS or under half receive buffer, whichever smaller. This
+// is useful as a receive side silly window syndrome prevention mechanism. If
// window grows to reasonable value, we should send ACK to the sender to inform
// the rx space is now large. We also want ensure a series of small read()'s
// won't trigger a flood of spurious tiny ACK's.
@@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// crossed will be true if the window size crossed the ACK threshold.
// above will be true if the new window is >= ACK threshold and false
// otherwise.
-func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
+//
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
newAvail := e.receiveBufferAvailableLocked()
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
@@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
mask := uint32(notifyReceiveWindowChanged)
+ e.mu.RLock()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Immediately send an ACK to uncork the sender silly window
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
-
+ e.mu.RUnlock()
e.notifyProtocolGoroutine(mask)
return nil
@@ -2414,13 +2420,14 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
func (e *endpoint) readyToRead(s *segment) {
+ e.mu.RLock()
e.rcvListMu.Lock()
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
// Increase counter if the receive window falls down below MSS
// or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
e.rcvList.PushBack(s)
@@ -2428,7 +2435,7 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
-
+ e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventIn)
}
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index c4591a3b9..579463384 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -1349,6 +1349,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
SyscallFailsWithErrno(ENOTCONN));
}
+TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ addrlen);
+ int buf_sz = 1 << 18;
+ EXPECT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));