summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2020-01-29 15:41:51 -0800
committergVisor bot <gvisor-bot@google.com>2020-01-29 15:53:45 -0800
commit51b783505b1ec164b02b48a0fd234509fba01a73 (patch)
treee68ead67e442b6f4df6e5852d8267f3d0fa37646
parent148fda60e8dee29f2df85e3104e3d5de1a225bcf (diff)
Add support for TCP_DEFER_ACCEPT.
PiperOrigin-RevId: 292233574
-rw-r--r--pkg/sentry/socket/netstack/netstack.go22
-rw-r--r--pkg/tcpip/tcpip.go6
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go25
-rw-r--r--pkg/tcpip/transport/tcp/connect.go53
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go26
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go126
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc158
-rw-r--r--test/syscalls/linux/tcp_socket.cc53
10 files changed, 451 insertions, 23 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 8619cc506..049d04bf2 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1260,6 +1260,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_DEFER_ACCEPT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPDeferAcceptOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1713,6 +1725,16 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+ case linux.TCP_DEFER_ACCEPT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))))
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 59c9b3fb0..0fa141d58 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -626,6 +626,12 @@ type TCPLingerTimeoutOption time.Duration
// before being marked closed.
type TCPTimeWaitTimeoutOption time.Duration
+// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a
+// accept to return a completed connection only when there is data to be
+// read. This usually means the listening socket will drop the final ACK
+// for a handshake till the specified timeout until a segment with data arrives.
+type TCPDeferAcceptOption time.Duration
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 4acd9fb9a..7b4a87a2d 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -57,6 +57,7 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d469758eb..6101f2945 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
- n := newEndpoint(l.stack, netProto, nil)
+ n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6only
n.ID = s.id
n.boundNICID = s.route.NICID()
@@ -273,16 +273,17 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// createEndpoint creates a new endpoint in connected state and then performs
// the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
if err != nil {
return nil, err
}
// listenEP is nil when listenContext is used by tcp.Forwarder.
+ deferAccept := time.Duration(0)
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
@@ -290,13 +291,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+ deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
// Perform the 3-way handshake.
- h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
-
- h.resetToSynRcvd(isn, irs, opts)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
if l.listenEP != nil {
@@ -377,16 +377,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer e.decSynRcvdCount()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(n)
@@ -546,7 +544,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -576,8 +574,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// space available in the backlog.
// Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4e3c5419c..9ff7ac261 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -86,6 +86,19 @@ type handshake struct {
// rcvWndScale is the receive window scale, as defined in RFC 1323.
rcvWndScale int
+
+ // startTime is the time at which the first SYN/SYN-ACK was sent.
+ startTime time.Time
+
+ // deferAccept if non-zero will drop the final ACK for a passive
+ // handshake till an ACK segment with data is received or the timeout is
+ // hit.
+ deferAccept time.Duration
+
+ // acked is true if the the final ACK for a 3-way handshake has
+ // been received. This is required to stop retransmitting the
+ // original SYN-ACK when deferAccept is enabled.
+ acked bool
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
@@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
return h
}
+func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
+ h := newHandshake(ep, rcvWnd)
+ h.resetToSynRcvd(isn, irs, opts, deferAccept)
+ return h
+}
+
// FindWndScale determines the window scale to use for the given maximum window
// size.
func FindWndScale(wnd seqnum.Size) int {
@@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.ackNum = irs + 1
h.mss = opts.MSS
h.sndWndScale = opts.WS
+ h.deferAccept = deferAccept
h.ep.mu.Lock()
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
@@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
+ // If deferAccept is not zero and this is a bare ACK and the
+ // timeout is not hit then drop the ACK.
+ if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ h.acked = true
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+
h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
+ // If the segment has data then requeue it for the receiver
+ // to process it again once main loop is started.
+ if s.data.Size() > 0 {
+ s.incRef()
+ h.ep.enqueueSegment(s)
+ }
h.ep.mu.Unlock()
-
return nil
}
@@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.startTime = time.Now()
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
@@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error {
switch index, _ := s.Fetch(true); index {
case wakerForResend:
timeOut *= 2
- if timeOut > 60*time.Second {
+ if timeOut > MaxRTO {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ // Resend the SYN/SYN-ACK only if the following conditions hold.
+ // - It's an active handshake (deferAccept does not apply)
+ // - It's a passive handshake and we have not yet got the final-ACK.
+ // - It's a passive handshake and we got an ACK but deferAccept is
+ // enabled and we are now past the deferAccept duration.
+ // The last is required to provide a way for the peer to complete
+ // the connection with another ACK or data (as ACKs are never
+ // retransmitted on their own).
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ }
case wakerForNotification:
n := h.ep.fetchNotifications()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 13718ff55..8d52414b7 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -498,6 +498,13 @@ type endpoint struct {
// without any data being acked.
userTimeout time.Duration
+ // deferAccept if non-zero specifies a user specified time during
+ // which the final ACK of a handshake will be dropped provided the
+ // ACK is a bare ACK and carries no data. If the timeout is crossed then
+ // the bare ACK is accepted and the connection is delivered to the
+ // listener.
+ deferAccept time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
@@ -1574,6 +1581,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ if time.Duration(v) > MaxRTO {
+ v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ }
+ e.deferAccept = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1798,6 +1814,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ *o = tcpip.TCPDeferAcceptOption(e.deferAccept)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2149,9 +2171,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
-func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+func (e *endpoint) startAcceptedLoop() {
e.mu.Lock()
- e.waiterQueue = waiterQueue
e.workerRunning = true
e.mu.Unlock()
wakerInitDone := make(chan struct{})
@@ -2177,7 +2198,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
-
return n, n.waiterQueue, nil
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 7eb613be5..c9ee5bf06 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
- })
+ }, queue)
if err != nil {
return nil, err
}
// Start the protocol goroutine.
- ep.startAcceptedLoop(queue)
+ ep.startAcceptedLoop()
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index df2fb1071..a12336d47 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -6787,3 +6787,129 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
),
)
}
+
+func TestTCPDeferAccept(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give a bit of time for the socket to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestTCPDeferAcceptTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Sleep for a little of the tcpDeferAccept timeout.
+ time.Sleep(tcpDeferAccept + 100*time.Millisecond)
+
+ // On timeout expiry we should get a SYN-ACK retransmission.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give sometime for the endpoint to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 2f9821555..3bf7081b9 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -828,6 +828,164 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) {
EXPECT_EQ(get, kUserTimeout);
}
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblock so that we can verify that there is no
+ // connection in queue despite the connect above succeeding since the peer has
+ // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the
+ // FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now write some data to the socket.
+ int data = 0;
+ ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+
+ // This should now cause the connection to complete and be delivered to the
+ // accept socket.
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Verify that the accepted socket returns the data written.
+ int get = -1;
+ ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0),
+ SyscallSucceedsWithValue(sizeof(get)));
+
+ EXPECT_EQ(get, data);
+}
+
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R once issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblock so that we can verify that there is no
+ // connection in queue despite the connect above succeeding since the peer has
+ // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the
+ // FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT
+ // timeout is hit.
+ absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1));
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout
+ // is hit a SYN-ACK should be retransmitted by the listener as a last ditch
+ // attempt to complete the connection with or without data.
+ absl::SleepFor(absl::Seconds(2));
+
+ // Verify that we have a connection that can be accepted even though no
+ // data was written.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+}
+
INSTANTIATE_TEST_SUITE_P(
All, SocketInetLoopbackTest,
::testing::Values(
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 33a5ac66c..525ccbd88 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -1286,6 +1286,59 @@ TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) {
EXPECT_EQ(get, kTCPUserTimeout);
}
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // -ve TCP_DEFER_ACCEPT is same as setting it to zero.
+ constexpr int kNeg = -1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // kTCPDeferAccept is in seconds.
+ // NOTE: linux translates seconds to # of retries and back from
+ // #of retries to seconds. Which means only certain values
+ // translate back exactly. That's why we use 3 here, a value of
+ // 5 will result in us getting back 7 instead of 5 in the
+ // getsockopt.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPDeferAccept);
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));