summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/header/ipv4.go9
-rw-r--r--pkg/tcpip/header/ipv6.go3
-rw-r--r--pkg/tcpip/tcpip.go20
-rw-r--r--pkg/tcpip/transport/tcp/connect.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go73
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go33
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go50
-rw-r--r--test/syscalls/linux/socket_generic_stress.cc49
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc112
-rw-r--r--test/syscalls/linux/tcp_socket.cc24
11 files changed, 362 insertions, 23 deletions
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index d0d1efd0d..680eafd16 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -315,3 +315,12 @@ func IsV4MulticastAddress(addr tcpip.Address) bool {
}
return (addr[0] & 0xf0) == 0xe0
}
+
+// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback
+// address (belongs to 127.0.0.1/8 subnet).
+func IsV4LoopbackAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv4AddressSize {
+ return false
+ }
+ return addr[0] == 0x7f
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 4f367fe4c..ea3823898 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -98,6 +98,9 @@ const (
// section 5.
IPv6MinimumMTU = 1280
+ // IPv6Loopback is the IPv6 Loopback address.
+ IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
// IPv6Any is the non-routable IPv6 "any" meta address. It is also
// known as the unspecified address.
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 091bc5281..07c85ce59 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -958,6 +958,26 @@ type SocketDetachFilterOption int
// and port of a redirected packet.
type OriginalDestinationOption FullAddress
+// TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to
+// specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for
+// new connections when it is safe from protocol viewpoint.
+type TCPTimeWaitReuseOption uint8
+
+const (
+ // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot
+ // be reused for new connections.
+ TCPTimeWaitReuseDisabled TCPTimeWaitReuseOption = iota
+
+ // TCPTimeWaitReuseGlobal indicates reuse of port bound by endponts in TIME-WAIT can
+ // be reused for new connections irrespective of the src/dest addresses.
+ TCPTimeWaitReuseGlobal
+
+ // TCPTimeWaitReuseLoopbackOnly indicates reuse of port bound by endpoint in TIME-WAIT can
+ // only be reused if the connection was a connection over loopback. i.e src/dest adddresses
+ // are loopback addresses.
+ TCPTimeWaitReuseLoopbackOnly
+)
+
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 8dd759ba2..46702906b 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1706,7 +1706,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 || n&notifyAbort != 0 {
+ if n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b8b52b03d..d08cfe0ff 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -449,10 +449,11 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
- //
- // recentTS must be read/written atomically.
recentTS uint32
+ // recentTSTime is the unix time when we updated recentTS last.
+ recentTSTime time.Time `state:".(unixTime)"`
+
// tsOffset is a randomized offset added to the value of the
// TSVal field in the timestamp option.
tsOffset uint32
@@ -795,15 +796,15 @@ func (e *endpoint) EndpointState() EndpointState {
return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
}
-// setRecentTimestamp atomically sets the recentTS field to the
-// provided value.
+// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
- atomic.StoreUint32(&e.recentTS, recentTS)
+ e.recentTS = recentTS
+ e.recentTSTime = time.Now()
}
-// recentTimestamp atomically reads and returns the value of the recentTS field.
+// recentTimestamp returns the value of the recentTS field.
func (e *endpoint) recentTimestamp() uint32 {
- return atomic.LoadUint32(&e.recentTS)
+ return e.recentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -902,7 +903,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case StateClose, StateError:
+ case StateClose, StateError, StateTimeWait:
// Ready for anything.
result = mask
@@ -2148,12 +2149,66 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
h.Write(portBuf)
portOffset := h.Sum32()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err))
+ }
+
+ reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal
+ if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
+ switch netProto {
+ case header.IPv4ProtocolNumber:
+ reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ case header.IPv6ProtocolNumber:
+ reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ }
+ }
+
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
- return false, nil
+ if err != tcpip.ErrPortInUse || !reuse {
+ return false, nil
+ }
+ transEPID := e.ID
+ transEPID.LocalPort = p
+ // Check if an endpoint is registered with demuxer in TIME-WAIT and if
+ // we can reuse it. If we can't find a transport endpoint then we just
+ // skip using this port as it's possible that either an endpoint has
+ // bound the port but not registered with demuxer yet (no listen/connect
+ // done yet) or the reservation was freed between the check above and
+ // the FindTransportEndpoint below. But rather than retry the same port
+ // we just skip it and move on.
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ if transEP == nil {
+ // ReservePort failed but there is no registered endpoint with
+ // demuxer. Which indicates there is at least some endpoint that has
+ // bound the port.
+ return false, nil
+ }
+
+ tcpEP := transEP.(*endpoint)
+ tcpEP.LockUser()
+ // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
+ // less than 1 second has elapsed since its recentTS was updated then
+ // we cannot reuse the port.
+ if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ tcpEP.UnlockUser()
+ return false, nil
+ }
+ // Since the endpoint is in TIME-WAIT it should be safe to acquire its
+ // Lock while holding the lock for this endpoint as endpoints in
+ // TIME-WAIT do not acquire locks on other endpoints.
+ tcpEP.workerCleanup = false
+ tcpEP.cleanupLocked()
+ tcpEP.notifyProtocolGoroutine(notifyAbort)
+ tcpEP.UnlockUser()
+ // Now try and Reserve again if it fails then we skip.
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ return false, nil
+ }
}
id := e.ID
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index abf1ac5c9..723e47ddc 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -309,6 +309,16 @@ func (e *endpoint) loadLastError(s string) {
e.lastError = tcpip.StringToError(s)
}
+// saveRecentTSTime is invoked by stateify.
+func (e *endpoint) saveRecentTSTime() unixTime {
+ return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
+}
+
+// loadRecentTSTime is invoked by stateify.
+func (e *endpoint) loadRecentTSTime(unix unixTime) {
+ e.recentTSTime = time.Unix(unix.second, unix.nano)
+}
+
// saveHardError is invoked by stateify.
func (e *EndpointInfo) saveHardError() string {
if e.HardError == nil {
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2e5093b36..49a673b42 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -191,8 +191,9 @@ type protocol struct {
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
- tcpLingerTimeout time.Duration
- tcpTimeWaitTimeout time.Duration
+ lingerTimeout time.Duration
+ timeWaitTimeout time.Duration
+ timeWaitReuse tcpip.TCPTimeWaitReuseOption
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
@@ -358,7 +359,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
v = 0
}
p.mu.Lock()
- p.tcpLingerTimeout = time.Duration(v)
+ p.lingerTimeout = time.Duration(v)
p.mu.Unlock()
return nil
@@ -367,7 +368,16 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
v = 0
}
p.mu.Lock()
- p.tcpTimeWaitTimeout = time.Duration(v)
+ p.timeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitReuseOption:
+ if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.timeWaitReuse = v
p.mu.Unlock()
return nil
@@ -468,13 +478,19 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
case *tcpip.TCPLingerTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout)
p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitReuseOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
p.mu.RUnlock()
return nil
@@ -564,8 +580,9 @@ func NewProtocol() stack.TransportProtocol {
},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
- tcpLingerTimeout: DefaultTCPLingerTimeout,
- tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ lingerTimeout: DefaultTCPLingerTimeout,
+ timeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
synRetries: DefaultSynRetries,
minRTO: MinRTO,
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 1b58eb91b..0f7e958e4 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -7297,3 +7297,53 @@ func TestResetDuringClose(t *testing.T) {
wg.Wait()
}
+
+func TestStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
+ }
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+}
+
+func TestSetStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ testCases := []struct {
+ v int
+ err *tcpip.Error
+ }{
+ {int(tcpip.TCPTimeWaitReuseDisabled), nil},
+ {int(tcpip.TCPTimeWaitReuseGlobal), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue},
+ {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue},
+ }
+
+ for _, tc := range testCases {
+ err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v))
+ if got, want := err, tc.err; got != want {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err)
+ }
+ if tc.err != nil {
+ continue
+ }
+
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
+ }
+
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+ }
+}
diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc
index 6a232238d..19239e9e9 100644
--- a/test/syscalls/linux/socket_generic_stress.cc
+++ b/test/syscalls/linux/socket_generic_stress.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <poll.h>
#include <stdio.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
@@ -37,6 +38,14 @@ TEST_P(ConnectStressTest, Reset65kTimes) {
char sent_data[100] = {};
ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
SyscallSucceedsWithValue(sizeof(sent_data)));
+ // Poll the other FD to make sure that the data is in the receive buffer
+ // before closing it to ensure a RST is triggered.
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = sockets->second_fd(),
+ .events = POLL_IN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
}
}
@@ -58,7 +67,45 @@ INSTANTIATE_TEST_SUITE_P(
// a persistent listener (if applicable).
using PersistentListenerConnectStressTest = SocketPairTest;
-TEST_P(PersistentListenerConnectStressTest, 65kTimes) {
+TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
+ if (GetParam().type == SOCK_STREAM) {
+ // Poll the other FD to make sure that we see the FIN from the other
+ // side before closing the second_fd. This ensures that the first_fd
+ // enters TIME-WAIT and not second_fd.
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = sockets->second_fd(),
+ .events = POLL_IN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ }
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds());
+ }
+}
+
+TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds());
+ if (GetParam().type == SOCK_STREAM) {
+ // Poll the other FD to make sure that we see the FIN from the other
+ // side before closing the first_fd. This ensures that the second_fd
+ // enters TIME-WAIT and not first_fd.
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = sockets->first_fd(),
+ .events = POLL_IN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ }
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
+ }
+}
+
+TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) {
for (int i = 0; i < 1 << 16; ++i) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
}
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 18b9e4b70..c3b42682f 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -865,7 +865,7 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) {
// results in the stack.Seed() being different which can cause
// sequence number of final connect to be one that is considered
// old and can cause the test to be flaky.
-TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
@@ -920,14 +920,27 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) {
&conn_addrlen),
SyscallSucceeds());
- // close the accept FD to trigger TIME_WAIT on the accepted socket which
+ // shutdown the accept FD to trigger TIME_WAIT on the accepted socket which
// should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of
// TIME_WAIT.
- accepted.reset();
- absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(shutdown(accepted.get(), SHUT_RDWR), SyscallSucceeds());
+ {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = conn_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+ }
+
conn_fd.reset();
+ // This sleep is required to give conn_fd time to transition to TIME-WAIT.
absl::SleepFor(absl::Seconds(1));
+ // At this point conn_fd should be the one that moved to CLOSE_WAIT and
+ // eventually to CLOSED.
+
// Now bind and connect a new socket and verify that we can immediately
// rebind the address bound by the conn_fd as it never entered TIME_WAIT.
const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
@@ -942,6 +955,97 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) {
SyscallSucceeds());
}
+TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // We disable saves after this point as a S/R causes the netstack seed
+ // to be regenerated which changes what ports/ISN is picked for a given
+ // tuple (src ip,src port, dst ip, dst port). This can cause the final
+ // SYN to use a sequence number that looks like one from the current
+ // connection in TIME_WAIT and will not be accepted causing the test
+ // to timeout.
+ //
+ // TODO(gvisor.dev/issue/940): S/R portSeed/portHint
+ DisableSave ds;
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // shutdown the conn FD to trigger TIME_WAIT on the connect socket.
+ ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds());
+ {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = accepted.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+ }
+ ScopedThread t([&]() {
+ constexpr int kTimeout = 10000;
+ constexpr int16_t want_events = POLLHUP;
+ struct pollfd pfd = {
+ .fd = conn_fd.get(),
+ .events = want_events,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ });
+
+ accepted.reset();
+ t.Join();
+ conn_fd.reset();
+
+ // Now bind and connect a new socket and verify that we can't immediately
+ // rebind the address bound by the conn_fd as it is in TIME_WAIT.
+ conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ conn_addrlen),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 0cea7d11f..a6325a761 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -720,6 +720,30 @@ TEST_P(TcpSocketTest, TcpSCMPriority) {
ASSERT_EQ(cmsg, nullptr);
}
+TEST_P(TcpSocketTest, TimeWaitPollHUP) {
+ shutdown(s_, SHUT_RDWR);
+ ScopedThread t([&]() {
+ constexpr int kTimeout = 10000;
+ constexpr int16_t want_events = POLLHUP;
+ struct pollfd pfd = {
+ .fd = s_,
+ .events = want_events,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ });
+ shutdown(t_, SHUT_RDWR);
+ t.Join();
+ // At this point s_ should be in TIME-WAIT and polling for POLLHUP should
+ // return with 1 FD.
+ constexpr int kTimeout = 10000;
+ constexpr int16_t want_events = POLLHUP;
+ struct pollfd pfd = {
+ .fd = s_,
+ .events = want_events,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest,
::testing::Values(AF_INET, AF_INET6));