From aee2c93366f451b9cc0a62430185749556fc3900 Mon Sep 17 00:00:00 2001 From: Jianfeng Tan Date: Thu, 29 Aug 2019 16:23:11 +0000 Subject: netstack: add counters for tcp CurrEstab and EstabResets Signed-off-by: Jianfeng Tan --- pkg/tcpip/transport/tcp/accept.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 844959fa0..2b4c5c2f9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -297,7 +297,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.state = StateEstablished + if ep.state != StateEstablished { + ep.stack.Stats().TCP.CurrentEstablished.Increment() + ep.state = StateEstablished + } ep.mu.Unlock() // Update the receive window scaling. We can't do it before the @@ -519,6 +522,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished // Do the delivery in a separate goroutine so -- cgit v1.2.3 From d277bfba2702b319d8336b65429cf8775661ea2f Mon Sep 17 00:00:00 2001 From: Jianfeng Tan Date: Mon, 20 May 2019 11:26:10 +0000 Subject: epsocket: support /proc/net/snmp Netstack has its own stats, we use this to fill /proc/net/snmp. Note that some metrics are not recorded in Netstack, which will be shown as 0 in the proc file. Signed-off-by: Jianfeng Tan Change-Id: Ie0089184507d16f49bc0057b4b0482094417ebe1 --- pkg/sentry/socket/netstack/stack.go | 93 ++++++++++++++++++++++++++++++++++++- pkg/tcpip/transport/tcp/accept.go | 6 +-- pkg/tcpip/transport/tcp/connect.go | 12 ++--- test/syscalls/linux/proc_net.cc | 23 +++++++++ 4 files changed, 121 insertions(+), 13 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fda0156e5..d5db8c17c 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserr.ErrEndpointOperation.ToError() + switch stats := stat.(type) { + case *inet.StatSNMPIP: + ip := Metrics.IP + *stats = inet.StatSNMPIP{ + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + } + case *inet.StatSNMPICMP: + in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + *stats = inet.StatSNMPICMP{ + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + in.DstUnreachable.Value(), // InDestUnreachs. + in.TimeExceeded.Value(), // InTimeExcds. + in.ParamProblem.Value(), // InParmProbs. + in.SrcQuench.Value(), // InSrcQuenchs. + in.Redirect.Value(), // InRedirects. + in.Echo.Value(), // InEchos. + in.EchoReply.Value(), // InEchoReps. + in.Timestamp.Value(), // InTimestamps. + in.TimestampReply.Value(), // InTimestampReps. + in.InfoRequest.Value(), // InAddrMasks. + in.InfoReply.Value(), // InAddrMaskReps. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. + } + case *inet.StatSNMPTCP: + tcp := Metrics.TCP + // RFC 2012 (updates 1213): SNMPv2-MIB-TCP. + *stats = inet.StatSNMPTCP{ + 1, // RtoAlgorithm. + 200, // RtoMin. + 120000, // RtoMax. + (1<<64 - 1), // MaxConn. + tcp.ActiveConnectionOpenings.Value(), // ActiveOpens. + tcp.PassiveConnectionOpenings.Value(), // PassiveOpens. + tcp.FailedConnectionAttempts.Value(), // AttemptFails. + tcp.EstablishedResets.Value(), // EstabResets. + tcp.CurrentEstablished.Value(), // CurrEstab. + tcp.ValidSegmentsReceived.Value(), // InSegs. + tcp.SegmentsSent.Value(), // OutSegs. + tcp.Retransmits.Value(), // RetransSegs. + tcp.InvalidSegmentsReceived.Value(), // InErrs. + tcp.ResetsSent.Value(), // OutRsts. + tcp.ChecksumErrors.Value(), // InCsumErrors. + } + case *inet.StatSNMPUDP: + udp := Metrics.UDP + *stats = inet.StatSNMPUDP{ + udp.PacketsReceived.Value(), // InDatagrams. + udp.UnknownPortErrors.Value(), // NoPorts. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + udp.PacketsSent.Value(), // OutDatagrams. + udp.ReceiveBufferErrors.Value(), // RcvbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + } + default: + return syserr.ErrEndpointOperation.ToError() + } + return nil } // RouteTable implements inet.Stack.RouteTable. diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2b4c5c2f9..65c346046 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -297,10 +297,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - if ep.state != StateEstablished { - ep.stack.Stats().TCP.CurrentEstablished.Increment() - ep.state = StateEstablished - } + ep.stack.Stats().TCP.CurrentEstablished.Increment() + ep.state = StateEstablished ep.mu.Unlock() // Update the receive window scaling. We can't do it before the diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4467dda82..b724d02bb 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -928,10 +928,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError e.HardError = err @@ -1126,10 +1124,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateClose } // Lock released below. diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index af4cd616a..d0ef8d380 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -15,11 +15,14 @@ #include #include #include +#include #include #include #include #include "absl/strings/str_split.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "gtest/gtest.h" #include "test/util/capability_util.h" #include "test/syscalls/linux/socket_test_util.h" @@ -184,11 +187,31 @@ TEST(ProcNetSnmp, TcpEstab) { EXPECT_EQ(oldPassiveOpens, newPassiveOpens - 1); EXPECT_EQ(oldCurrEstab, newCurrEstab - 2); + // Send 1 byte from client to server. ASSERT_THAT(send(s_connect.get(), "a", 1, 0), SyscallSucceedsWithValue(1)); + constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. + + // Wait until server-side fd sees the data on its side but don't read it. + struct pollfd poll_fd = {s_accept.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close server-side fd without reading the data which leads to a RST + // packet sent to client side. s_accept.reset(-1); + + // Wait until client-side fd sees RST packet. + struct pollfd poll_fd1 = {s_connect.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd1, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close client-side fd. s_connect.reset(-1); + // Wait until the process of the netstack. + absl::SleepFor(absl::Seconds(1.0)); + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); newEstabResets = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); -- cgit v1.2.3 From 0864549ecc26e734bae3dcf40e0d761232f8bdad Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 28 Oct 2019 18:19:12 -0700 Subject: Use the user supplied TCP MSS when creating a new active socket This change supports using a user supplied TCP MSS for new active TCP connections. Note, the user supplied MSS must be less than or equal to the maximum possible MSS for a TCP connection's route. If it is greater than the maximum possible MSS, the maximum possible MSS will be used as the connection's MSS instead. This change does not use this user supplied MSS for connections accepted from listening sockets - that will come in a later change. Test: Test that outgoing TCP SYN segments contain a TCP MSS option with the user supplied MSS if it is not greater than the maximum possible MSS for the route. PiperOrigin-RevId: 277185125 --- pkg/tcpip/stack/nic.go | 5 ++ pkg/tcpip/transport/tcp/accept.go | 6 +- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 24 ++++++- pkg/tcpip/transport/tcp/tcp_test.go | 124 ++++++++++++++++++++++++++++++++++++ 5 files changed, 155 insertions(+), 6 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a867f8c00..a01a208b8 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -89,6 +89,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + nic := &NIC{ stack: stack, id: id, diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 65c346046..1dd00d026 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -400,6 +400,9 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + // TODO(b/143300739): Use the userMSS of the listening socket + // for accepted sockets. + switch s.flags { case header.TCPFlagSyn: opts := parseSynSegmentOptions(s) @@ -434,13 +437,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: mssForRoute(&s.route), } e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 790e89cc3..ca982c451 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -442,7 +442,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6ca0d73a9..8234a8b53 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -411,7 +411,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -504,6 +504,21 @@ type endpoint struct { stats Stats `state:"nosave"` } +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS +} + // StopWork halts packet processing. Only to be used in tests. func (e *endpoint) StopWork() { e.workMu.Lock() @@ -752,7 +767,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -1206,7 +1223,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidOptionValue } e.mu.Lock() - e.userMSS = int(userMSS) + e.userMSS = uint16(userMSS) e.mu.Unlock() e.notifyProtocolGoroutine(notifyMSSChanged) return nil @@ -2383,5 +2400,6 @@ func (e *endpoint) Stats() tcpip.EndpointStats { } func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6d808328c..126f26ed3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -474,6 +474,130 @@ func TestSimpleReceive(t *testing.T) { ) } +// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when +// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV4(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.Create(-1) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() -- cgit v1.2.3 From d0d89ceeddd21f1f22e818d78dc3b07d3669dbb5 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 6 Nov 2019 10:42:00 -0800 Subject: Send a TCP RST in response to a TCP SYN-ACK on a listening endpoint This change better follows what is outlined in RFC 793 section 3.4 figure 12 where a listening socket should not accept a SYN-ACK segment in response to a (potentially) old SYN segment. Tests: Test that checks the TCP RST segment sent in response to a TCP SYN-ACK segment received on a listening TCP endpoint. PiperOrigin-RevId: 278893114 --- pkg/tcpip/transport/tcp/accept.go | 9 ++++++ pkg/tcpip/transport/tcp/segment.go | 10 +++++-- pkg/tcpip/transport/tcp/tcp_test.go | 56 +++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1dd00d026..cb0e13ebc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -400,6 +401,14 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + return + } + // TODO(b/143300739): Use the userMSS of the listening socket // for accepted sockets. diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index ea725d513..c4a89525e 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -99,8 +99,14 @@ func (s *segment) clone() *segment { return t } -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags } func (s *segment) decRef() { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 126f26ed3..beaa40210 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -598,6 +598,62 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) { } } +func TestSendRstOnListenerRxSynAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxSynAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() -- cgit v1.2.3 From 66ebb6575f929a389d3c929977ed5e31d706fcfe Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 7 Nov 2019 09:45:26 -0800 Subject: Add support for TIME_WAIT timeout. This change adds explicit support for honoring the 2MSL timeout for sockets in TIME_WAIT state. It also adds support for the TCP_LINGER2 option that allows modification of the FIN_WAIT2 state timeout duration for a given socket. It also adds an option to modify the Stack wide TIME_WAIT timeout but this is only for testing. On Linux this is fixed at 60s. Further, we also now correctly process RST's in CLOSE_WAIT and close the socket similar to linux without moving it to error state. We also now handle SYN in ESTABLISHED state as per RFC5961#section-4.1. Earlier we would just drop these SYNs. Which can result in some tests that pass on linux to fail on gVisor. Netstack now honors TIME_WAIT correctly as well as handles the following cases correctly. - TCP RSTs in TIME_WAIT are ignored. - A duplicate TCP FIN during TIME_WAIT extends the TIME_WAIT and a dup ACK is sent in response to the FIN as the dup FIN indicates potential loss of the original final ACK. - An out of order segment during TIME_WAIT generates a dup ACK. - A new SYN w/ a sequence number > the highest sequence number in the previous connection closes the TIME_WAIT early and opens a new connection. Further to make the SYN case work correctly the ISN (Initial Sequence Number) generation for Netstack has been updated to be as per RFC. Its not a pure random number anymore and follows the recommendation in https://tools.ietf.org/html/rfc6528#page-3. The current hash used is not a cryptographically secure hash function. A separate change will update the hash function used to Siphash similar to what is used in Linux. PiperOrigin-RevId: 279106406 --- pkg/sentry/socket/netstack/netstack.go | 20 + pkg/tcpip/adapters/gonet/gonet_test.go | 12 +- pkg/tcpip/stack/stack.go | 20 +- pkg/tcpip/stack/transport_demuxer.go | 33 +- pkg/tcpip/tcpip.go | 12 +- pkg/tcpip/transport/tcp/BUILD | 2 +- pkg/tcpip/transport/tcp/accept.go | 17 +- pkg/tcpip/transport/tcp/connect.go | 322 ++++++++++++-- pkg/tcpip/transport/tcp/endpoint.go | 101 ++++- pkg/tcpip/transport/tcp/endpoint_state.go | 26 +- pkg/tcpip/transport/tcp/protocol.go | 43 ++ pkg/tcpip/transport/tcp/rcv.go | 167 ++++++- pkg/tcpip/transport/tcp/tcp_test.go | 622 ++++++++++++++++++++++++++- test/syscalls/BUILD | 22 +- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/socket_inet_loopback.cc | 336 +++++++++++++++ test/syscalls/linux/socket_ip_tcp_generic.cc | 93 ++++ 17 files changed, 1736 insertions(+), 113 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 27c6692c4..d92399efd 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1173,6 +1173,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa copy(b, v) return b, nil + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1556,6 +1568,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..ee077ae83 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -151,10 +151,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -203,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 99809df75..2f8d8e822 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -402,11 +402,11 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // portSeed is a one-time random value initialized at stack startup + // seed is a one-time random value initialized at stack startup // and is used to seed the TCP port picking on active connections // // TODO(gvisor.dev/issue/940): S/R this field. - portSeed uint32 + seed uint32 // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations @@ -544,7 +544,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, icmpRateLimiter: NewICMPRateLimiter(), - portSeed: generateRandUint32(), + seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, @@ -1186,6 +1186,12 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { s.mu.Unlock() } +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1573,12 +1579,12 @@ func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRoute return nil } -// PortSeed returns a 32 bit value that can be used as a seed value for port -// picking. +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. // // NOTE: The seed is generated once during stack initialization only. -func (s *Stack) PortSeed() uint32 { - return s.portSeed +func (s *Stack) Seed() uint32 { + return s.seed } func generateRandUint32() uint32 { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 594570216..cb805522b 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -103,7 +103,6 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } - // multiPortEndpoints are guaranteed to have at least one element. selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. @@ -507,10 +506,40 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr if ep, ok := eps.endpoints[nid]; ok { matchedEPs = append(matchedEPs, ep) } - return matchedEPs } +// findTransportEndpoint find a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil + } + // Try to find the endpoint. + eps.mu.RLock() + epsByNic := d.findEndpointLocked(eps, id) + // Fail if we didn't find one. + if epsByNic == nil { + eps.mu.RUnlock() + return nil + } + + epsByNic.mu.RLock() + eps.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } + } + + ep := selectEndpoint(id, mpep, epsByNic.seed) + epsByNic.mu.RUnlock() + return ep +} + // findEndpointLocked returns the endpoint that most closely matches the given // id. func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3edb513d4..bd5eb89ca 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -586,6 +586,16 @@ type MaxSegOption int // A zero value indicates the default. type TTLOption uint8 +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration + +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -1329,8 +1339,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f1dbc6f91..3f47b328d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -71,7 +71,7 @@ filegroup( go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index cb0e13ebc..0e8e0a2b4 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -269,8 +269,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -289,7 +289,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -361,6 +361,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -368,6 +369,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -543,6 +549,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index ca982c451..a114c06c1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -139,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -809,7 +836,19 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.state = StateError e.HardError = err if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1< + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -856,7 +960,15 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -955,7 +1067,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1001,6 +1112,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // RTT itself. e.rcvAutoParams.prevCopied = initialRcvWnd e.rcvListMu.Unlock() + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.mu.Lock() + e.state = StateEstablished + e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -1008,10 +1123,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - if e.state != StateEstablished { - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished - } drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1042,7 +1153,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.state = StateClose + e.mu.Unlock() + return nil }, }, { @@ -1085,17 +1202,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1117,6 +1235,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1143,15 +1267,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1167,6 +1292,23 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. @@ -1176,8 +1318,130 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateClose } + // Lock released below. epilogue() + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 79fec6b77..04c92c04c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -320,6 +325,11 @@ type endpoint struct { state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` @@ -503,6 +513,16 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -599,6 +619,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.SetSockOptInt(tcpip.DelayOption, 1) } + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -686,6 +711,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -706,6 +738,8 @@ func (e *endpoint) Close() { e.isPortReserved = false } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -731,9 +765,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -1349,6 +1381,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1562,6 +1616,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1696,7 +1756,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1782,9 +1842,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1799,6 +1858,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1817,14 +1877,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1832,11 +1889,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } @@ -1928,12 +1994,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -2058,6 +2119,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 19f003b6b..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,7 +195,6 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } @@ -219,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) { if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -265,8 +280,11 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c8e4a0d7e..89b965c23 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,6 +55,14 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP @@ -93,6 +102,8 @@ type protocol struct { congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -212,6 +223,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -262,6 +291,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -274,5 +315,7 @@ func NewProtocol() stack.TransportProtocol { recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..068b90fb6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -209,6 +210,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: @@ -253,23 +259,105 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil + } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + r.ep.mu.RLock() + state := r.ep.state + closed := r.ep.closed + r.ep.mu.RUnlock() + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil } // Defer segment processing if it can't be consumed now. @@ -288,7 +376,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -315,4 +403,67 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed -= s.logicalLen() s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index f4ea5f091..0c1704d74 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -206,17 +206,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -256,7 +257,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -264,6 +265,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -285,6 +293,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -376,6 +389,13 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -396,12 +416,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -413,7 +445,7 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), ), @@ -1110,8 +1142,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -3085,6 +3116,13 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) @@ -3092,12 +3130,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3135,10 +3173,9 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -3155,7 +3192,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -3166,7 +3203,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -3176,11 +3213,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -4347,7 +4384,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4893,3 +4931,545 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) } } + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(uint32(ackHeaders.SeqNum)), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) +} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 3e5b6b3c3..722d14b53 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -9,7 +9,7 @@ syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test") syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", ) @@ -434,7 +434,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", ) @@ -445,7 +445,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", ) @@ -458,19 +458,19 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", ) @@ -481,13 +481,13 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", ) syscall_test( size = "medium", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", ) @@ -498,7 +498,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", ) @@ -560,7 +560,7 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", ) @@ -599,7 +599,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 93bff8299..f8b8cb724 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2141,6 +2141,7 @@ cc_library( deps = [ ":socket_test_util", "//test/util:test_util", + "//test/util:thread_util", "@com_google_googletest//:gtest", ], alwayslink = 1, diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index ab375aaaf..2eeee352e 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -31,6 +32,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" @@ -267,6 +269,340 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { } } +// TCPFinWait2Test creates a pair of connected sockets then closes one end to +// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local +// IP/port on a new socket and tries to connect. The connect should fail w/ +// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the +// connect again with a new socket and this time it should succeed. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Lower FIN_WAIT2 state to 5 seconds for test. + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + // Now bind and connect a new socket. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Disable cooperative saves after this point. As a save between the first + // bind/connect and the second one can cause the linger timeout timer to + // be restarted causing the final bind/connect to fail. + DisableSave ds; + + // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple + // reservations which causes the bind() to succeed on gVisor but connect + // correctly fails. + if (IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } else { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } + + // Sleep for a little over the linger timeout to reduce flakiness in + // save/restore tests. + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + ds.reset(); + + if (!IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + } + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPLinger2TimeoutAfterClose creates a pair of connected sockets +// then closes one end to trigger FIN_WAIT2 state for the closed endpont. +// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ +// connecting the same address succeeds. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPResetAfterClose creates a pair of connected sockets then closes +// one end to trigger FIN_WAIT2 state for the closed endpoint verifies +// that we generate RSTs for any new data after the socket is fully +// closed. +TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + int data = 1234; + + // Now send data which should trigger a RST as the other end should + // have timed out and closed the socket. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceeds()); + // Sleep for a shortwhile to get a RST back. + absl::SleepFor(absl::Seconds(1)); + + // Try writing again and we should get an EPIPE back. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallFailsWithErrno(EPIPE)); + + // Trying to read should return zero as the other end did send + // us a FIN. We do it twice to verify that the RST does not cause an + // ECONNRESET on the read after EOF has been read by applicaiton. + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); +} + +// This test is disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the accept FD to trigger TIME_WAIT on the accepted socket which + // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of + // TIME_WAIT. + accepted.reset(); + absl::SleepFor(absl::Seconds(1)); + conn_fd.reset(); + absl::SleepFor(absl::Seconds(1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 592448289..a37b49447 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -26,6 +26,7 @@ #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -243,6 +244,31 @@ TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { SyscallSucceedsWithValue(0)); } +// This test verifies that a shutdown(wr) by the server after sending +// data allows the client to still read() the queued data and a client +// close after sending response allows server to read the incoming +// response. +TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char buf[10] = {}; + ScopedThread t([&]() { + ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(close(sockets->release_first_fd()), + SyscallSucceedsWithValue(0)); + }); + ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR), + SyscallSucceedsWithValue(0)); + t.Join(); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); +} + TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -696,5 +722,72 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); } +// Linux and Netstack both default to a 60s TCP_LINGER2 timeout. +constexpr int kDefaultTCPLingerTimeout = 60; + +TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + + constexpr int kNegative = -1234; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kNegative, sizeof(kNegative)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kAboveDefault, sizeof(kAboveDefault)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPLingerTimeout); +} + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 773071680021a2fb985f3a3af7e9f65cdc1bd1ed Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Mon, 11 Nov 2019 14:13:50 -0800 Subject: Make `connect` on socket returned by `accept` correctly error out with EISCONN PiperOrigin-RevId: 279814493 --- pkg/tcpip/transport/tcp/accept.go | 2 ++ pkg/tcpip/transport/tcp/tcp_test.go | 3 +++ test/syscalls/linux/tcp_socket.cc | 13 +++++++++++++ 3 files changed, 18 insertions(+) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 0e8e0a2b4..f24b51b91 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -300,6 +300,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.mu.Lock() ep.stack.Stats().TCP.CurrentEstablished.Increment() ep.state = StateEstablished + ep.isConnectNotified = true ep.mu.Unlock() // Update the receive window scaling. We can't do it before the @@ -539,6 +540,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // Switch state to connected. n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished + n.isConnectNotified = true // Do the delivery in a separate goroutine so // that we don't block the listen loop in case diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0c1704d74..84579ce52 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4599,6 +4599,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { + t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) + } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 277d6835a..bfc77ffc2 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -130,6 +130,19 @@ void TcpSocketTest::TearDown() { } } +TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + ASSERT_THAT( + connect(s_, reinterpret_cast(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); + ASSERT_THAT( + connect(t_, reinterpret_cast(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); +} + TEST_P(TcpSocketTest, DataCoalesced) { char buf[10]; -- cgit v1.2.3 From 3e534f2974f469a889534221b83c3bbbd1b0318c Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Fri, 15 Nov 2019 11:44:02 -0800 Subject: Handle in-flight TCP segments when moving to CLOSE. As we move to CLOSE state from LAST-ACK or TIME-WAIT, ensure that we re-match all in-flight segments to any listening endpoint. Also fix LISTEN state handling of any ACK segments as per RFC793. Fixes #1153 PiperOrigin-RevId: 280703556 --- pkg/tcpip/transport/tcp/accept.go | 14 ++- pkg/tcpip/transport/tcp/connect.go | 60 +++++++++-- pkg/tcpip/transport/tcp/rcv.go | 2 +- pkg/tcpip/transport/tcp/tcp_test.go | 196 ++++++++++++++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 11 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f24b51b91..023045ec1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -419,8 +419,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // TODO(b/143300739): Use the userMSS of the listening socket // for accepted sockets. - switch s.flags { - case header.TCPFlagSyn: + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) if incSynRcvdCount() { // Only handle the syn if the following conditions hold @@ -464,7 +464,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -478,6 +478,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } if !synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 49f2b9685..364067731 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -865,6 +865,33 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateCloseLocked ensures that the endpoint is +// cleaned up from the transport demuxer, "before" moving to +// StateClose. This will ensure that no packet will be +// delivered to this endpoint from the demuxer when the endpoint +// is transitioned to StateClose. +func (e *endpoint) transitionToStateCloseLocked() { + if e.state == StateClose { + return + } + e.cleanupLocked() + e.state = StateClose +} + +// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed +// segment to any other endpoint other than the current one. This is called +// only when the endpoint is in StateClose and we want to deliver the segment +// to any other listening endpoint. We reply with RST if we cannot find one. +func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil { + replyWithReset(s) + s.decRef() + return + } + ep.(*endpoint).enqueueSegment(s) +} + func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states @@ -894,12 +921,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // general "connection reset" signal. Enter the CLOSED state, // delete the TCB, and return. case StateCloseWait: - e.state = StateClose + e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted - // We need to set this explicitly here because otherwise - // the port registrations will not be released till the - // endpoint is actively closed by the application. - e.workerCleanup = true e.mu.Unlock() return false, nil default: @@ -915,6 +938,20 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + return nil + } + s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false @@ -1160,7 +1197,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // to the TCP_FIN_WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() - e.state = StateClose + e.transitionToStateCloseLocked() e.mu.Unlock() return nil }, @@ -1321,12 +1358,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if e.state != StateError { e.stack.Stats().TCP.EstablishedResets.Increment() e.stack.Stats().TCP.CurrentEstablished.Decrement() - e.state = StateClose + e.transitionToStateCloseLocked() } // Lock released below. epilogue() + // epilogue removes the endpoint from the transport-demuxer and + // unlocks e.mu. Now that no new segments can get enqueued to this + // endpoint, try to re-match the segment to a different endpoint + // as the current endpoint is closed. + for !e.segmentQueue.empty() { + s := e.segmentQueue.dequeue() + e.tryDeliverSegmentFromClosedEndpoint(s) + } + // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue if reuseTW != nil { diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 068b90fb6..857dc445f 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -218,7 +218,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateClosing: r.ep.state = StateTimeWait case StateLastAck: - r.ep.state = StateClose + r.ep.transitionToStateCloseLocked() } r.ep.mu.Unlock() } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b443fe9dc..64f765c70 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -454,6 +454,112 @@ func TestConnectResetAfterClose(t *testing.T) { } } +// TestClosingWithEnqueuedSegments tests handling of +// still enqueued segments when the endpoint transitions +// to StateClose. The in-flight segments would be re-enqueued +// to a any listening endpoint. +func TestClosingWithEnqueuedSegments(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Send a FIN for ESTABLISHED --> CLOSED-WAIT + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Get the ACK for the FIN we sent. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Close the application endpoint for CLOSE_WAIT --> LAST_ACK + ep.Close() + + // Get the FIN + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Pause the endpoint`s protocolMainLoop. + ep.(interface{ StopWork() }).StopWork() + + // Enqueue last ACK followed by an ACK matching the endpoint + // + // Send Last ACK for LAST_ACK --> CLOSED + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Send a packet with ACK set, this would generate RST when + // not using SYN cookies as in this test. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 792, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Unpause endpoint`s protocolMainLoop. + ep.(interface{ ResumeWork() }).ResumeWork() + + // Wait for the protocolMainLoop to resume and update state. + time.Sleep(1 * time.Millisecond) + + // Expect the endpoint to be closed. + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Check if the endpoint was moved to CLOSED and netstack a reset in + // response to the ACK packet that we sent after last-ACK. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + ), + ) +} + func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -686,6 +792,96 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.SeqNum(200))) } +func TestSendRstOnListenerRxAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true /* v6Only */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(200))) +} + +// TestListenShutdown tests for the listening endpoint not processing +// any receive when it is on read shutdown. +func TestListenShutdown(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } + + // Wait for the endpoint state to be propagated. + time.Sleep(10 * time.Millisecond) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) + + c.CheckNoPacket("Packet received when listening socket was shutdown") +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() -- cgit v1.2.3 From 8eb68912e40bc87c932baeb13d151fd590d7d279 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Fri, 22 Nov 2019 14:56:54 -0800 Subject: Store SO_BINDTODEVICE state at bind. This allows us to ensure that the correct port reservation is released. Fixes #1217 PiperOrigin-RevId: 282048155 --- pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 26 ++++++++++++++++------- pkg/tcpip/transport/udp/endpoint.go | 35 ++++++++++++++++++++----------- pkg/tcpip/transport/udp/endpoint_state.go | 2 +- 4 files changed, 43 insertions(+), 22 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 023045ec1..f543a6105 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -243,7 +243,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() return nil, err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 04c92c04c..9d4a87e30 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -340,6 +340,9 @@ type endpoint struct { // TCP should never broadcast but Linux nevertheless supports enabling/ // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint + // (which ever happens first). + boundBindToDevice tcpip.NICID // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -730,12 +733,13 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) e.isPortReserved = false + e.boundBindToDevice = 0 } // Mark endpoint as closed. @@ -791,14 +795,15 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) e.isPortReserved = false } + e.boundBindToDevice = 0 e.route.Release() e.stack.CompleteTransportEndpointCleanup(e) @@ -1741,7 +1746,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) if err != nil { return err } @@ -1778,7 +1783,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id.LocalPort = p switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: + // Port picking successful. Save the details of + // the selected port. e.ID = id + e.boundBindToDevice = e.bindToDevice return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1794,7 +1802,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundBindToDevice) e.isPortReserved = false } @@ -1950,7 +1958,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { return err } @@ -2031,6 +2039,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { return err } + e.boundBindToDevice = e.bindToDevice e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port @@ -2044,8 +2053,9 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 + e.boundBindToDevice = 0 } - }(e.bindToDevice) + }(e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 2d97d1398..23c1da717 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -104,6 +104,10 @@ type endpoint struct { bindToDevice tcpip.NICID broadcast bool + // Values used to reserve a port or register a transport endpoint. + // (which ever happens first). + boundBindToDevice tcpip.NICID + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 @@ -175,8 +179,9 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.boundBindToDevice = 0 } for _, mem := range e.multicastMemberships { @@ -870,7 +875,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { if e.state != StateConnected { return nil } - id := stack.TransportEndpointID{} + var ( + id stack.TransportEndpointID + btd tcpip.NICID + ) // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error @@ -878,7 +886,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } @@ -886,13 +894,14 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.ID = id + e.boundBindToDevice = btd e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -961,17 +970,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } - id, err = e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) } e.ID = id + e.boundBindToDevice = btd e.route = r.Clone() e.dstPort = addr.Port e.RegisterNICID = nicID @@ -1029,11 +1039,11 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { - return id, err + return id, e.bindToDevice, err } id.LocalPort = port } @@ -1042,7 +1052,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) } - return id, err + return id, e.bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -1081,12 +1091,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, err = e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id + e.boundBindToDevice = btd e.RegisterNICID = nicID e.effectiveNetProtos = netProtos diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index b227e353b..43fb047ed 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -109,7 +109,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } -- cgit v1.2.3 From 3e84777d2e2a2b56c00487cd77aa8d2fc25bbb16 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Fri, 6 Dec 2019 15:45:01 -0800 Subject: Fix flakiness in tcp_test. This change marks the socket as ESTABLISHED and creates the receiver and sender the moment we send the final ACK in case of an active TCP handshake or when we receive the final ACK for a passive TCP handshake. Before this change there was a short window in which an ACK can be received and processed but the state on the socket is not yet ESTABLISHED. This can be seen in TestConnectBindToDevice which is flaky because sometimes the socket is in SYN-SENT and not ESTABLISHED even though the other side has already received the final ACK of the handshake. PiperOrigin-RevId: 284277713 --- pkg/tcpip/transport/tcp/accept.go | 4 +- pkg/tcpip/transport/tcp/connect.go | 52 ++++++++++++++-------- pkg/tcpip/transport/tcp/tcp_test.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 3 ++ 4 files changed, 39 insertions(+), 22 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f543a6105..74df3edfb 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -298,8 +298,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.stack.Stats().TCP.CurrentEstablished.Increment() - ep.state = StateEstablished ep.isConnectNotified = true ep.mu.Unlock() @@ -546,6 +544,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + // We do not use transitionToStateEstablishedLocked here as there is + // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 16f8aea12..2975a1c3c 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -252,6 +252,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -352,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + return nil } @@ -880,6 +889,30 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateEstablisedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver if required. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + if e.snd == nil { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + } + if e.rcv == nil { + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + } + h.ep.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished +} + // transitionToStateCloseLocked ensures that the endpoint is // cleaned up from the transport demuxer, "before" moving to // StateClose. This will ensure that no packet will be @@ -1156,25 +1189,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.mu.Lock() - e.state = StateEstablished - e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index d1f0d6ce7..52c2fa7e3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -541,7 +541,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { ep.(interface{ ResumeWork() }).ResumeWork() // Wait for the protocolMainLoop to resume and update state. - time.Sleep(1 * time.Millisecond) + time.Sleep(10 * time.Millisecond) // Expect the endpoint to be closed. if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6cb66c1af..b0a376eba 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -231,6 +231,7 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -259,6 +260,7 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -483,6 +485,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // GetV6Packet reads a single packet from the link layer endpoint of the context // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv6.ProtocolNumber { -- cgit v1.2.3 From 6fc9f0aefd89ce42ef2c38ea7853f9ba7c4bee04 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 11 Dec 2019 17:51:37 -0800 Subject: Add support for TCP_USER_TIMEOUT option. The implementation follows the linux behavior where specifying a TCP_USER_TIMEOUT will cause the resend timer to honor the user specified timeout rather than the default rto based timeout. Further it alters when connections are timedout due to keepalive failures. It does not alter the behavior of when keepalives are sent. This is as per the linux behavior. PiperOrigin-RevId: 285099795 --- pkg/sentry/socket/netstack/netstack.go | 23 ++++ pkg/tcpip/tcpip.go | 5 + pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 15 +++ pkg/tcpip/transport/tcp/connect.go | 19 ++- pkg/tcpip/transport/tcp/endpoint.go | 19 +++ pkg/tcpip/transport/tcp/protocol.go | 21 ++- pkg/tcpip/transport/tcp/rcv.go | 19 ++- pkg/tcpip/transport/tcp/rcv_state.go | 29 ++++ pkg/tcpip/transport/tcp/snd.go | 48 ++++++- pkg/tcpip/transport/tcp/snd_state.go | 10 ++ pkg/tcpip/transport/tcp/tcp_test.go | 194 ++++++++++++++++++++++++--- test/syscalls/linux/socket_inet_loopback.cc | 56 +++++++- test/syscalls/linux/socket_ip_tcp_generic.cc | 63 +++++++++ test/syscalls/linux/tcp_socket.cc | 25 ++++ 15 files changed, 509 insertions(+), 38 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/rcv_state.go (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index fe5a46aa3..8a6522eac 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1127,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Millisecond), nil + case linux.TCP_INFO: var v tcpip.TCPInfoOption if err := ep.GetSockOpt(&v); err != nil { @@ -1563,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d5bb5b6ed..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -576,6 +576,11 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user +// specified timeout for a given TCP connection. +// See: RFC5482 for details. +type TCPUserTimeoutOption time.Duration + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 455a1c098..3b353d56c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -28,6 +28,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 74df3edfb..5422ae80c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -242,6 +242,13 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() + // Now inherit any socket options that should be inherited from the + // listening endpoint. + // In case of Forwarder listenEP will be nil and hence this check. + if l.listenEP != nil { + l.listenEP.propagateInheritableOptions(n) + } + // Register new endpoint so that packets are routed to it. if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() @@ -350,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } +// propagateInheritableOptions propagates any options set on the listening +// endpoint to the newly created endpoint. +func (e *endpoint) propagateInheritableOptions(n *endpoint) { + e.mu.Lock() + n.userTimeout = e.userTimeout + e.mu.Unlock() +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3d059c302..4c34fc9d2 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -862,7 +862,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } e.state = StateError e.HardError = err - if err != tcpip.ErrConnectionReset { + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -1087,12 +1087,24 @@ func (e *endpoint) handleSegments() *tcpip.Error { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.mu.RLock() + userTimeout := e.userTimeout + e.mu.RUnlock() + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() @@ -1112,7 +1124,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -1120,6 +1131,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -1127,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -1239,6 +1252,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil @@ -1405,6 +1419,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if s == nil { break } + e.tryDeliverSegmentFromClosedEndpoint(s) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4861ab513..dd8b47cbe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -341,6 +341,7 @@ type endpoint struct { // TCP should never broadcast but Linux nevertheless supports enabling/ // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID @@ -474,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -1333,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1591,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 89b965c23..bc718064c 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -162,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo func replyWithReset(s *segment) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 5ee499c36..0a5534959 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -50,16 +50,20 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } @@ -360,6 +364,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { return true, nil } + // Store the time of the last ack. + r.lastRcvdAckTime = time.Now() + // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go new file mode 100644 index 000000000..2bf21a2e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8332a0179..8a947dc66 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -28,8 +28,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. - minRTO = 200 * time.Millisecond + // MinRTO is the minimum allowed value for the retransmit timeout. + MinRTO = 200 * time.Millisecond + + // MaxRTO is the maximum allowed value for the retransmit timeout. + MaxRTO = 120 * time.Second // InitialCwnd is the initial congestion window. InitialCwnd = 10 @@ -134,6 +137,10 @@ type sender struct { // rttMeasureTime is the time when the rttMeasureSeqNum was sent. rttMeasureTime time.Time `state:".(unixTime)"` + // firstRetransmittedSegXmitTime is the original transmit time of + // the first segment that was retransmitted due to RTO expiration. + firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + closed bool writeNext *segment writeList segmentList @@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < minRTO { - s.rto = minRTO + if s.rto < MinRTO { + s.rto = MinRTO } } @@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool { s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() - // Give up if we've waited more than a minute since the last resend. - if s.rto >= 60*time.Second { + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + s.ep.mu.RLock() + uto := s.ep.userTimeout + s.ep.mu.RUnlock() + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. + s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } + + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := MaxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + if remaining <= 0 || s.rto >= MaxRTO { return false } @@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool { // below. s.rto *= 2 + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. // // Retransmit timeouts: @@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) { // RFC 6298 Rule 5.3 if s.sndUna == s.sndNxt { s.outstanding = 0 + // Reset firstRetransmittedSegXmitTime to the zero value. + s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() } } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 12eff8afc..8b20c3455 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) } + +// saveFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { + return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} +} + +// loadFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { + s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index bc5cfcf0e..2a83f7bcc 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -323,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) } func TestTCPResetsReceivedIncrement(t *testing.T) { @@ -460,18 +460,17 @@ func TestConnectResetAfterClose(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) break } } -// TestClosingWithEnqueuedSegments tests handling of -// still enqueued segments when the endpoint transitions -// to StateClose. The in-flight segments would be re-enqueued -// to a any listening endpoint. +// TestClosingWithEnqueuedSegments tests handling of still enqueued segments +// when the endpoint transitions to StateClose. The in-flight segments would be +// re-enqueued to a any listening endpoint. func TestClosingWithEnqueuedSegments(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -576,8 +575,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) } @@ -914,7 +913,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -942,7 +941,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -4291,8 +4290,9 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) @@ -4382,13 +4382,29 @@ func TestKeepalive(t *testing.T) { ) } + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + // The connection should be terminated after 5 unacked keepalives. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), ) @@ -6157,8 +6173,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) @@ -6336,7 +6352,147 @@ func TestTCPCloseWithData(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + userTimeout := 50 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for a little over the minimum retransmit timeout of 200ms for + // the retransmitTimer to fire and close the connection. + time.Sleep(tcp.MinRTO + 10*time.Millisecond) + + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") + + next += uint32(len(view)) + + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} + +func TestKeepaliveWithUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + const keepAliveInterval = 10 * time.Millisecond + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) + c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + + // Set userTimeout to be the duration for 3 keepalive probes. + userTimeout := 30 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Now receive 2 keepalives, but don't ACK them. The connection should + // be reset when the 3rd one should be sent due to userTimeout being + // 30ms and each keepalive probe should be sent 10ms apart as set above after + // the connection has been idle for 10ms. + for i := 0; i < 2; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + + // The connection should be terminated after 30ms. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index fa4358ae4..761c3a9fe 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -206,7 +206,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked // before function end. - // ds.reset() + // ds.reset(); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -603,6 +603,60 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { SyscallSucceeds()); } +TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the userTimeout on the listening socket. + constexpr int kUserTimeout = 10; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kUserTimeout, sizeof(kUserTimeout)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + // Verify that the accepted socket inherited the user timeout set on + // listening socket. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kUserTimeout); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index c74273436..57ce8e169 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -812,5 +812,68 @@ TEST_P(TCPSocketPairTest, TestTCPCloseWithData) { ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); } +TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kZero, sizeof(kZero)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kNeg = -10; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kNeg, sizeof(kNeg)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAbove = 10; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kAbove, sizeof(kAbove)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kAbove); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 99863b0ed..c503f3568 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1175,6 +1175,31 @@ TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) { } } +TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + { + constexpr int kTCPUserTimeout = -1; + EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallFailsWithErrno(EINVAL)); + } + + // kTCPUserTimeout is in milliseconds. + constexpr int kTCPUserTimeout = 100; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPUserTimeout); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 27500d529f7fb87eef8812278fd1bbca67bcba72 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 9 Jan 2020 22:00:42 -0800 Subject: New sync package. * Rename syncutil to sync. * Add aliases to sync types. * Replace existing usage of standard library sync package. This will make it easier to swap out synchronization primitives. For example, this will allow us to use primitives from github.com/sasha-s/go-deadlock to check for lock ordering violations. Updates #1472 PiperOrigin-RevId: 289033387 --- pkg/amutex/BUILD | 1 + pkg/amutex/amutex_test.go | 3 +- pkg/atomicbitops/BUILD | 1 + pkg/atomicbitops/atomic_bitops_test.go | 3 +- pkg/compressio/BUILD | 5 +- pkg/compressio/compressio.go | 2 +- pkg/control/server/BUILD | 1 + pkg/control/server/server.go | 2 +- pkg/eventchannel/BUILD | 2 + pkg/eventchannel/event.go | 2 +- pkg/eventchannel/event_test.go | 2 +- pkg/fdchannel/BUILD | 1 + pkg/fdchannel/fdchannel_test.go | 3 +- pkg/fdnotifier/BUILD | 1 + pkg/fdnotifier/fdnotifier.go | 2 +- pkg/flipcall/BUILD | 3 +- pkg/flipcall/flipcall_example_test.go | 3 +- pkg/flipcall/flipcall_test.go | 3 +- pkg/flipcall/flipcall_unsafe.go | 10 +- pkg/gate/BUILD | 1 + pkg/gate/gate_test.go | 2 +- pkg/linewriter/BUILD | 1 + pkg/linewriter/linewriter.go | 3 +- pkg/log/BUILD | 5 +- pkg/log/log.go | 2 +- pkg/metric/BUILD | 1 + pkg/metric/metric.go | 2 +- pkg/p9/BUILD | 1 + pkg/p9/client.go | 2 +- pkg/p9/p9test/BUILD | 2 + pkg/p9/p9test/client_test.go | 2 +- pkg/p9/p9test/p9test.go | 2 +- pkg/p9/path_tree.go | 3 +- pkg/p9/pool.go | 2 +- pkg/p9/server.go | 2 +- pkg/p9/transport.go | 2 +- pkg/procid/BUILD | 2 + pkg/procid/procid_test.go | 3 +- pkg/rand/BUILD | 5 +- pkg/rand/rand_linux.go | 2 +- pkg/refs/BUILD | 2 + pkg/refs/refcounter.go | 2 +- pkg/refs/refcounter_test.go | 3 +- pkg/sentry/arch/BUILD | 1 + pkg/sentry/arch/arch_x86.go | 2 +- pkg/sentry/control/BUILD | 1 + pkg/sentry/control/pprof.go | 2 +- pkg/sentry/device/BUILD | 5 +- pkg/sentry/device/device.go | 2 +- pkg/sentry/fs/BUILD | 3 +- pkg/sentry/fs/copy_up.go | 2 +- pkg/sentry/fs/copy_up_test.go | 2 +- pkg/sentry/fs/dirent.go | 2 +- pkg/sentry/fs/dirent_cache.go | 3 +- pkg/sentry/fs/dirent_cache_limiter.go | 3 +- pkg/sentry/fs/fdpipe/BUILD | 1 + pkg/sentry/fs/fdpipe/pipe.go | 2 +- pkg/sentry/fs/fdpipe/pipe_state.go | 2 +- pkg/sentry/fs/file.go | 2 +- pkg/sentry/fs/file_overlay.go | 2 +- pkg/sentry/fs/filesystems.go | 2 +- pkg/sentry/fs/fs.go | 3 +- pkg/sentry/fs/fsutil/BUILD | 1 + pkg/sentry/fs/fsutil/host_file_mapper.go | 2 +- pkg/sentry/fs/fsutil/host_mappable.go | 2 +- pkg/sentry/fs/fsutil/inode.go | 3 +- pkg/sentry/fs/fsutil/inode_cached.go | 2 +- pkg/sentry/fs/gofer/BUILD | 1 + pkg/sentry/fs/gofer/inode.go | 2 +- pkg/sentry/fs/gofer/session.go | 2 +- pkg/sentry/fs/host/BUILD | 1 + pkg/sentry/fs/host/inode.go | 2 +- pkg/sentry/fs/host/socket.go | 2 +- pkg/sentry/fs/host/tty.go | 3 +- pkg/sentry/fs/inode.go | 3 +- pkg/sentry/fs/inode_inotify.go | 3 +- pkg/sentry/fs/inotify.go | 2 +- pkg/sentry/fs/inotify_watch.go | 2 +- pkg/sentry/fs/lock/BUILD | 1 + pkg/sentry/fs/lock/lock.go | 2 +- pkg/sentry/fs/mounts.go | 2 +- pkg/sentry/fs/overlay.go | 5 +- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/seqfile/BUILD | 1 + pkg/sentry/fs/proc/seqfile/seqfile.go | 2 +- pkg/sentry/fs/proc/sys_net.go | 2 +- pkg/sentry/fs/ramfs/BUILD | 1 + pkg/sentry/fs/ramfs/dir.go | 2 +- pkg/sentry/fs/restore.go | 2 +- pkg/sentry/fs/tmpfs/BUILD | 1 + pkg/sentry/fs/tmpfs/inode_file.go | 2 +- pkg/sentry/fs/tty/BUILD | 1 + pkg/sentry/fs/tty/dir.go | 2 +- pkg/sentry/fs/tty/line_discipline.go | 2 +- pkg/sentry/fs/tty/queue.go | 3 +- pkg/sentry/fsimpl/ext/BUILD | 1 + pkg/sentry/fsimpl/ext/directory.go | 3 +- pkg/sentry/fsimpl/ext/filesystem.go | 2 +- pkg/sentry/fsimpl/ext/regular_file.go | 2 +- pkg/sentry/fsimpl/kernfs/BUILD | 2 + pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 2 +- pkg/sentry/fsimpl/tmpfs/BUILD | 1 + pkg/sentry/fsimpl/tmpfs/regular_file.go | 2 +- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 2 +- pkg/sentry/kernel/BUILD | 5 +- pkg/sentry/kernel/abstract_socket_namespace.go | 2 +- pkg/sentry/kernel/auth/BUILD | 3 +- pkg/sentry/kernel/auth/user_namespace.go | 2 +- pkg/sentry/kernel/epoll/BUILD | 1 + pkg/sentry/kernel/epoll/epoll.go | 2 +- pkg/sentry/kernel/eventfd/BUILD | 1 + pkg/sentry/kernel/eventfd/eventfd.go | 2 +- pkg/sentry/kernel/fasync/BUILD | 1 + pkg/sentry/kernel/fasync/fasync.go | 3 +- pkg/sentry/kernel/fd_table.go | 2 +- pkg/sentry/kernel/fd_table_test.go | 2 +- pkg/sentry/kernel/fs_context.go | 2 +- pkg/sentry/kernel/futex/BUILD | 8 +- pkg/sentry/kernel/futex/futex.go | 3 +- pkg/sentry/kernel/futex/futex_test.go | 2 +- pkg/sentry/kernel/kernel.go | 2 +- pkg/sentry/kernel/memevent/BUILD | 1 + pkg/sentry/kernel/memevent/memory_events.go | 2 +- pkg/sentry/kernel/pipe/BUILD | 1 + pkg/sentry/kernel/pipe/buffer.go | 2 +- pkg/sentry/kernel/pipe/node.go | 3 +- pkg/sentry/kernel/pipe/pipe.go | 2 +- pkg/sentry/kernel/pipe/pipe_util.go | 2 +- pkg/sentry/kernel/pipe/vfs.go | 3 +- pkg/sentry/kernel/semaphore/BUILD | 1 + pkg/sentry/kernel/semaphore/semaphore.go | 2 +- pkg/sentry/kernel/shm/BUILD | 1 + pkg/sentry/kernel/shm/shm.go | 2 +- pkg/sentry/kernel/signal_handlers.go | 3 +- pkg/sentry/kernel/signalfd/BUILD | 1 + pkg/sentry/kernel/signalfd/signalfd.go | 3 +- pkg/sentry/kernel/syscalls.go | 2 +- pkg/sentry/kernel/syslog.go | 3 +- pkg/sentry/kernel/task.go | 5 +- pkg/sentry/kernel/thread_group.go | 2 +- pkg/sentry/kernel/threads.go | 2 +- pkg/sentry/kernel/time/BUILD | 1 + pkg/sentry/kernel/time/time.go | 2 +- pkg/sentry/kernel/timekeeper.go | 2 +- pkg/sentry/kernel/tty.go | 2 +- pkg/sentry/kernel/uts_namespace.go | 3 +- pkg/sentry/limits/BUILD | 1 + pkg/sentry/limits/limits.go | 3 +- pkg/sentry/mm/BUILD | 2 +- pkg/sentry/mm/aio_context.go | 3 +- pkg/sentry/mm/mm.go | 8 +- pkg/sentry/pgalloc/BUILD | 1 + pkg/sentry/pgalloc/pgalloc.go | 2 +- pkg/sentry/platform/interrupt/BUILD | 1 + pkg/sentry/platform/interrupt/interrupt.go | 3 +- pkg/sentry/platform/kvm/BUILD | 1 + pkg/sentry/platform/kvm/address_space.go | 2 +- pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go | 2 - pkg/sentry/platform/kvm/kvm.go | 2 +- pkg/sentry/platform/kvm/machine.go | 2 +- pkg/sentry/platform/ptrace/BUILD | 1 + pkg/sentry/platform/ptrace/ptrace.go | 2 +- pkg/sentry/platform/ptrace/subprocess.go | 2 +- .../platform/ptrace/subprocess_linux_unsafe.go | 2 +- pkg/sentry/platform/ring0/defs.go | 2 +- pkg/sentry/platform/ring0/defs_amd64.go | 1 + pkg/sentry/platform/ring0/defs_arm64.go | 1 + pkg/sentry/platform/ring0/pagetables/BUILD | 5 +- pkg/sentry/platform/ring0/pagetables/pcids_x86.go | 2 +- pkg/sentry/socket/netlink/BUILD | 1 + pkg/sentry/socket/netlink/port/BUILD | 1 + pkg/sentry/socket/netlink/port/port.go | 3 +- pkg/sentry/socket/netlink/socket.go | 2 +- pkg/sentry/socket/netstack/BUILD | 1 + pkg/sentry/socket/netstack/netstack.go | 2 +- pkg/sentry/socket/rpcinet/conn/BUILD | 1 + pkg/sentry/socket/rpcinet/conn/conn.go | 2 +- pkg/sentry/socket/rpcinet/notifier/BUILD | 1 + pkg/sentry/socket/rpcinet/notifier/notifier.go | 2 +- pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 3 +- pkg/sentry/socket/unix/transport/queue.go | 3 +- pkg/sentry/socket/unix/transport/unix.go | 2 +- pkg/sentry/syscalls/linux/BUILD | 1 + pkg/sentry/syscalls/linux/error.go | 2 +- pkg/sentry/time/BUILD | 4 +- pkg/sentry/time/calibrated_clock.go | 2 +- pkg/sentry/usage/BUILD | 1 + pkg/sentry/usage/memory.go | 2 +- pkg/sentry/vfs/BUILD | 3 +- pkg/sentry/vfs/dentry.go | 2 +- pkg/sentry/vfs/file_description_impl_util.go | 2 +- pkg/sentry/vfs/mount_test.go | 3 +- pkg/sentry/vfs/mount_unsafe.go | 4 +- pkg/sentry/vfs/pathname.go | 3 +- pkg/sentry/vfs/resolving_path.go | 2 +- pkg/sentry/vfs/vfs.go | 2 +- pkg/sentry/watchdog/BUILD | 1 + pkg/sentry/watchdog/watchdog.go | 2 +- pkg/sync/BUILD | 53 +++++++ pkg/sync/LICENSE | 27 ++++ pkg/sync/README.md | 5 + pkg/sync/aliases.go | 37 +++++ pkg/sync/atomicptr_unsafe.go | 47 +++++++ pkg/sync/atomicptrtest/BUILD | 29 ++++ pkg/sync/atomicptrtest/atomicptr_test.go | 31 +++++ pkg/sync/downgradable_rwmutex_test.go | 150 ++++++++++++++++++++ pkg/sync/downgradable_rwmutex_unsafe.go | 146 ++++++++++++++++++++ pkg/sync/memmove_unsafe.go | 28 ++++ pkg/sync/norace_unsafe.go | 35 +++++ pkg/sync/race_unsafe.go | 41 ++++++ pkg/sync/seqatomic_unsafe.go | 72 ++++++++++ pkg/sync/seqatomictest/BUILD | 33 +++++ pkg/sync/seqatomictest/seqatomic_test.go | 132 ++++++++++++++++++ pkg/sync/seqcount.go | 149 ++++++++++++++++++++ pkg/sync/seqcount_test.go | 153 +++++++++++++++++++++ pkg/sync/syncutil.go | 7 + pkg/syncutil/BUILD | 52 ------- pkg/syncutil/LICENSE | 27 ---- pkg/syncutil/README.md | 5 - pkg/syncutil/atomicptr_unsafe.go | 47 ------- pkg/syncutil/atomicptrtest/BUILD | 29 ---- pkg/syncutil/atomicptrtest/atomicptr_test.go | 31 ----- pkg/syncutil/downgradable_rwmutex_test.go | 150 -------------------- pkg/syncutil/downgradable_rwmutex_unsafe.go | 146 -------------------- pkg/syncutil/memmove_unsafe.go | 28 ---- pkg/syncutil/norace_unsafe.go | 35 ----- pkg/syncutil/race_unsafe.go | 41 ------ pkg/syncutil/seqatomic_unsafe.go | 72 ---------- pkg/syncutil/seqatomictest/BUILD | 35 ----- pkg/syncutil/seqatomictest/seqatomic_test.go | 132 ------------------ pkg/syncutil/seqcount.go | 149 -------------------- pkg/syncutil/seqcount_test.go | 153 --------------------- pkg/syncutil/syncutil.go | 7 - pkg/tcpip/BUILD | 1 + pkg/tcpip/adapters/gonet/BUILD | 1 + pkg/tcpip/adapters/gonet/gonet.go | 2 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 2 +- pkg/tcpip/link/sharedmem/BUILD | 2 + pkg/tcpip/link/sharedmem/pipe/BUILD | 1 + pkg/tcpip/link/sharedmem/pipe/pipe_test.go | 3 +- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/network/fragmentation/BUILD | 1 + pkg/tcpip/network/fragmentation/fragmentation.go | 2 +- pkg/tcpip/network/fragmentation/reassembler.go | 2 +- pkg/tcpip/ports/BUILD | 1 + pkg/tcpip/ports/ports.go | 2 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/linkaddrcache.go | 2 +- pkg/tcpip/stack/linkaddrcache_test.go | 2 +- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/stack/transport_demuxer.go | 2 +- pkg/tcpip/tcpip.go | 2 +- pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 3 +- pkg/tcpip/transport/packet/BUILD | 1 + pkg/tcpip/transport/packet/endpoint.go | 3 +- pkg/tcpip/transport/raw/BUILD | 1 + pkg/tcpip/transport/raw/endpoint.go | 3 +- pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 2 +- pkg/tcpip/transport/tcp/endpoint_state.go | 2 +- pkg/tcpip/transport/tcp/forwarder.go | 3 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/tcp/segment_queue.go | 2 +- pkg/tcpip/transport/tcp/snd.go | 2 +- pkg/tcpip/transport/udp/BUILD | 1 + pkg/tcpip/transport/udp/endpoint.go | 3 +- pkg/tmutex/BUILD | 1 + pkg/tmutex/tmutex_test.go | 3 +- pkg/unet/BUILD | 1 + pkg/unet/unet_test.go | 3 +- pkg/urpc/BUILD | 1 + pkg/urpc/urpc.go | 2 +- pkg/waiter/BUILD | 1 + pkg/waiter/waiter.go | 2 +- runsc/boot/BUILD | 2 + runsc/boot/compat.go | 2 +- runsc/boot/limits.go | 2 +- runsc/boot/loader.go | 2 +- runsc/boot/loader_test.go | 2 +- runsc/cmd/BUILD | 1 + runsc/cmd/create.go | 1 + runsc/cmd/gofer.go | 2 +- runsc/cmd/start.go | 1 + runsc/container/BUILD | 2 + runsc/container/console_test.go | 2 +- runsc/container/container_test.go | 2 +- runsc/container/multi_container_test.go | 2 +- runsc/container/state_file.go | 2 +- runsc/fsgofer/BUILD | 1 + runsc/fsgofer/fsgofer.go | 2 +- runsc/sandbox/BUILD | 1 + runsc/sandbox/sandbox.go | 2 +- runsc/testutil/BUILD | 1 + runsc/testutil/testutil.go | 2 +- 303 files changed, 1507 insertions(+), 1368 deletions(-) create mode 100644 pkg/sync/BUILD create mode 100644 pkg/sync/LICENSE create mode 100644 pkg/sync/README.md create mode 100644 pkg/sync/aliases.go create mode 100644 pkg/sync/atomicptr_unsafe.go create mode 100644 pkg/sync/atomicptrtest/BUILD create mode 100644 pkg/sync/atomicptrtest/atomicptr_test.go create mode 100644 pkg/sync/downgradable_rwmutex_test.go create mode 100644 pkg/sync/downgradable_rwmutex_unsafe.go create mode 100644 pkg/sync/memmove_unsafe.go create mode 100644 pkg/sync/norace_unsafe.go create mode 100644 pkg/sync/race_unsafe.go create mode 100644 pkg/sync/seqatomic_unsafe.go create mode 100644 pkg/sync/seqatomictest/BUILD create mode 100644 pkg/sync/seqatomictest/seqatomic_test.go create mode 100644 pkg/sync/seqcount.go create mode 100644 pkg/sync/seqcount_test.go create mode 100644 pkg/sync/syncutil.go delete mode 100644 pkg/syncutil/BUILD delete mode 100644 pkg/syncutil/LICENSE delete mode 100644 pkg/syncutil/README.md delete mode 100644 pkg/syncutil/atomicptr_unsafe.go delete mode 100644 pkg/syncutil/atomicptrtest/BUILD delete mode 100644 pkg/syncutil/atomicptrtest/atomicptr_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_unsafe.go delete mode 100644 pkg/syncutil/memmove_unsafe.go delete mode 100644 pkg/syncutil/norace_unsafe.go delete mode 100644 pkg/syncutil/race_unsafe.go delete mode 100644 pkg/syncutil/seqatomic_unsafe.go delete mode 100644 pkg/syncutil/seqatomictest/BUILD delete mode 100644 pkg/syncutil/seqatomictest/seqatomic_test.go delete mode 100644 pkg/syncutil/seqcount.go delete mode 100644 pkg/syncutil/seqcount_test.go delete mode 100644 pkg/syncutil/syncutil.go (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 6bc486b62..d99e37b40 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["amutex_test.go"], embed = [":amutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go index 1d7f45641..8a3952f2a 100644 --- a/pkg/amutex/amutex_test.go +++ b/pkg/amutex/amutex_test.go @@ -15,9 +15,10 @@ package amutex import ( - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) type sleeper struct { diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 36beaade9..6403c60c2 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -20,4 +20,5 @@ go_test( size = "small", srcs = ["atomic_bitops_test.go"], embed = [":atomicbitops"], + deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomic_bitops_test.go index 965e9be79..9466d3e23 100644 --- a/pkg/atomicbitops/atomic_bitops_test.go +++ b/pkg/atomicbitops/atomic_bitops_test.go @@ -16,8 +16,9 @@ package atomicbitops import ( "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) const iterations = 100 diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index a0b21d4bd..2bb581b18 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["compressio.go"], importpath = "gvisor.dev/gvisor/pkg/compressio", visibility = ["//:sandbox"], - deps = ["//pkg/binary"], + deps = [ + "//pkg/binary", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 3b0bb086e..5f52cbe74 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -52,9 +52,9 @@ import ( "hash" "io" "runtime" - "sync" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" ) var bufPool = sync.Pool{ diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD index 21adf3adf..adbd1e3f8 100644 --- a/pkg/control/server/BUILD +++ b/pkg/control/server/BUILD @@ -9,6 +9,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", ], diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go index a56152d10..41abe1f2d 100644 --- a/pkg/control/server/server.go +++ b/pkg/control/server/server.go @@ -22,9 +22,9 @@ package server import ( "os" - "sync" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 0b4b7cc44..9d68682c7 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ ":eventchannel_go_proto", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_golang_protobuf//ptypes:go_default_library_gen", @@ -40,6 +41,7 @@ go_test( srcs = ["event_test.go"], embed = [":eventchannel"], deps = [ + "//pkg/sync", "@com_github_golang_protobuf//proto:go_default_library", ], ) diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go index d37ad0428..9a29c58bd 100644 --- a/pkg/eventchannel/event.go +++ b/pkg/eventchannel/event.go @@ -22,13 +22,13 @@ package eventchannel import ( "encoding/binary" "fmt" - "sync" "syscall" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 3649097d6..7f41b4a27 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -16,11 +16,11 @@ package eventchannel import ( "fmt" - "sync" "testing" "time" "github.com/golang/protobuf/proto" + "gvisor.dev/gvisor/pkg/sync" ) // testEmitter is an emitter that can be used in tests. It records all events diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index 56495cbd9..b0478c672 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["fdchannel_test.go"], embed = [":fdchannel"], + deps = ["//pkg/sync"], ) diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go index 5d01dc636..7a8a63a59 100644 --- a/pkg/fdchannel/fdchannel_test.go +++ b/pkg/fdchannel/fdchannel_test.go @@ -17,10 +17,11 @@ package fdchannel import ( "io/ioutil" "os" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSendRecvFD(t *testing.T) { diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD index aca2d8a82..91a202a30 100644 --- a/pkg/fdnotifier/BUILD +++ b/pkg/fdnotifier/BUILD @@ -11,6 +11,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/fdnotifier", visibility = ["//:sandbox"], deps = [ + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go index f4aae1953..a6b63c982 100644 --- a/pkg/fdnotifier/fdnotifier.go +++ b/pkg/fdnotifier/fdnotifier.go @@ -22,10 +22,10 @@ package fdnotifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index e590a71ba..85bd83af1 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -19,7 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", - "//pkg/syncutil", + "//pkg/sync", ], ) @@ -31,4 +31,5 @@ go_test( "flipcall_test.go", ], embed = [":flipcall"], + deps = ["//pkg/sync"], ) diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go index 8d88b845d..2e28a149a 100644 --- a/pkg/flipcall/flipcall_example_test.go +++ b/pkg/flipcall/flipcall_example_test.go @@ -17,7 +17,8 @@ package flipcall import ( "bytes" "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) func Example() { diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index 168a487ec..33fd55a44 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -16,9 +16,10 @@ package flipcall import ( "runtime" - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) var testPacketWindowSize = pageSize diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 27b8939fc..ac974b232 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -18,7 +18,7 @@ import ( "reflect" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte { var ioSync int64 func raceBecomeActive() { - if syncutil.RaceEnabled { - syncutil.RaceAcquire((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceAcquire((unsafe.Pointer)(&ioSync)) } } func raceBecomeInactive() { - if syncutil.RaceEnabled { - syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) } } diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index 4b9321711..f22bd070d 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -19,5 +19,6 @@ go_test( ], deps = [ ":gate", + "//pkg/sync", ], ) diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go index 5dbd8d712..850693df8 100644 --- a/pkg/gate/gate_test.go +++ b/pkg/gate/gate_test.go @@ -15,11 +15,11 @@ package gate_test import ( - "sync" "testing" "time" "gvisor.dev/gvisor/pkg/gate" + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicEnter(t *testing.T) { diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index a5d980d14..bcde6d308 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["linewriter.go"], importpath = "gvisor.dev/gvisor/pkg/linewriter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go index cd6e4e2ce..a1b1285d4 100644 --- a/pkg/linewriter/linewriter.go +++ b/pkg/linewriter/linewriter.go @@ -17,7 +17,8 @@ package linewriter import ( "bytes" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Writer is an io.Writer which buffers input, flushing diff --git a/pkg/log/BUILD b/pkg/log/BUILD index fc5f5779b..0df0f2849 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -16,7 +16,10 @@ go_library( visibility = [ "//visibility:public", ], - deps = ["//pkg/linewriter"], + deps = [ + "//pkg/linewriter", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/log/log.go b/pkg/log/log.go index 9387586e6..91a81b288 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -25,12 +25,12 @@ import ( stdlog "log" "os" "runtime" - "sync" "sync/atomic" "syscall" "time" "gvisor.dev/gvisor/pkg/linewriter" + "gvisor.dev/gvisor/pkg/sync" ) // Level is the log level. diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index dd6ca6d39..9145f3233 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -14,6 +14,7 @@ go_library( ":metric_go_proto", "//pkg/eventchannel", "//pkg/log", + "//pkg/sync", ], ) diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index eadde06e4..93d4f2b8c 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -18,12 +18,12 @@ package metric import ( "errors" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index f32244c69..a3e05c96d 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/fdchannel", "//pkg/flipcall", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 221516c6c..4045e41fa 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -17,12 +17,12 @@ package p9 import ( "errors" "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 28707c0ca..f4edd68b2 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -70,6 +70,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/unet", "@com_github_golang_mock//gomock:go_default_library", ], @@ -83,6 +84,7 @@ go_test( deps = [ "//pkg/fd", "//pkg/p9", + "//pkg/sync", "@com_github_golang_mock//gomock:go_default_library", ], ) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e758148d..6e7bb3db2 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -22,7 +22,6 @@ import ( "os" "reflect" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" ) func TestPanic(t *testing.T) { diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 4d3271b37..dd8b01b6d 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -17,13 +17,13 @@ package p9test import ( "fmt" - "sync" "sync/atomic" "syscall" "testing" "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go index 865459411..72ef53313 100644 --- a/pkg/p9/path_tree.go +++ b/pkg/p9/path_tree.go @@ -16,7 +16,8 @@ package p9 import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // pathNode is a single node in a path traversal. diff --git a/pkg/p9/pool.go b/pkg/p9/pool.go index 52de889e1..2b14a5ce3 100644 --- a/pkg/p9/pool.go +++ b/pkg/p9/pool.go @@ -15,7 +15,7 @@ package p9 import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // pool is a simple allocator. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 40b8fa023..fdfa83648 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -17,7 +17,6 @@ package p9 import ( "io" "runtime/debug" - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/fdchannel" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 6e8b4bbcd..9c11e28ce 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -19,11 +19,11 @@ import ( "fmt" "io" "io/ioutil" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 078f084b2..b506813f0 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -21,6 +21,7 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) go_test( @@ -31,4 +32,5 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go index 88dd0b3ae..9ec08c3d6 100644 --- a/pkg/procid/procid_test.go +++ b/pkg/procid/procid_test.go @@ -17,9 +17,10 @@ package procid import ( "os" "runtime" - "sync" "syscall" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) // runOnMain is used to send functions to run on the main (initial) thread. diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD index f4f2001f3..9d5b4859b 100644 --- a/pkg/rand/BUILD +++ b/pkg/rand/BUILD @@ -10,5 +10,8 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/rand", visibility = ["//:sandbox"], - deps = ["@org_golang_x_sys//unix:go_default_library"], + deps = [ + "//pkg/sync", + "@org_golang_x_sys//unix:go_default_library", + ], ) diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go index 2b92db3e6..0bdad5fad 100644 --- a/pkg/rand/rand_linux.go +++ b/pkg/rand/rand_linux.go @@ -19,9 +19,9 @@ package rand import ( "crypto/rand" "io" - "sync" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" ) // reader implements an io.Reader that returns pseudorandom bytes. diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 7ad59dfd7..974d9af9b 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -27,6 +27,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", ], ) @@ -35,4 +36,5 @@ go_test( size = "small", srcs = ["refcounter_test.go"], embed = [":refs"], + deps = ["//pkg/sync"], ) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index ad69e0757..c45ba8200 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -21,10 +21,10 @@ import ( "fmt" "reflect" "runtime" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) // RefCounter is the interface to be implemented by objects that are reference diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go index ffd3d3f07..1ab4a4440 100644 --- a/pkg/refs/refcounter_test.go +++ b/pkg/refs/refcounter_test.go @@ -16,8 +16,9 @@ package refs import ( "reflect" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) type testCounter struct { diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 18c73cc24..ae3e364cd 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/limits", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 9294ac773..9f41e566f 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -19,7 +19,6 @@ package arch import ( "fmt" "io" - "sync" "syscall" "gvisor.dev/gvisor/pkg/binary" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 5522cecd0..2561a6109 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/strace", "//pkg/sentry/usage", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", ], diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index e1f2fea60..151808911 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -19,10 +19,10 @@ import ( "runtime" "runtime/pprof" "runtime/trace" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 1098ed777..97fa1512c 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["device.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/device", visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/abi/linux"], + deps = [ + "//pkg/abi/linux", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go index 47945d1a7..69e71e322 100644 --- a/pkg/sentry/device/device.go +++ b/pkg/sentry/device/device.go @@ -19,10 +19,10 @@ package device import ( "bytes" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Registry tracks all simple devices and related state on the system for diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index c035ffff7..7d5d72d5a 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -68,7 +68,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -115,6 +115,7 @@ go_test( "//pkg/sentry/fs/tmpfs", "//pkg/sentry/kernel/contexttest", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 9ac62c84d..734177e90 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -17,12 +17,12 @@ package fs import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index 1d80bf15a..738580c5f 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -19,13 +19,13 @@ import ( "crypto/rand" "fmt" "io" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/fs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 3cb73bd78..31fc4d87b 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -18,7 +18,6 @@ import ( "fmt" "path" "sort" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 60a15a275..25514ace4 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCache is an LRU cache of Dirents. The Dirent's refCount is diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go index ebb80bd50..525ee25f9 100644 --- a/pkg/sentry/fs/dirent_cache_limiter.go +++ b/pkg/sentry/fs/dirent_cache_limiter.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCacheLimiter acts as a global limit for all dirent caches in the diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 277ee4c31..cc43de69d 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 669ffcb75..5b6cfeb0a 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -17,7 +17,6 @@ package fdpipe import ( "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go index 29175fb3d..cee87f726 100644 --- a/pkg/sentry/fs/fdpipe/pipe_state.go +++ b/pkg/sentry/fs/fdpipe/pipe_state.go @@ -17,10 +17,10 @@ package fdpipe import ( "fmt" "io/ioutil" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // beforeSave is invoked by stateify. diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index a2f966cb6..7c4586296 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -16,7 +16,6 @@ package fs import ( "math" - "sync" "sync/atomic" "time" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 225e40186..8a633b1ba 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -16,13 +16,13 @@ package fs import ( "io" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go index b157fd228..c5b51620a 100644 --- a/pkg/sentry/fs/filesystems.go +++ b/pkg/sentry/fs/filesystems.go @@ -18,9 +18,9 @@ import ( "fmt" "sort" "strings" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags. diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index 8b2a5e6b2..26abf49e2 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -54,10 +54,9 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 9ca695a95..945b6270d 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -93,6 +93,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index b06a71cc2..837fc70b5 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -16,7 +16,6 @@ package fsutil import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostFileMapper caches mappings of an arbitrary host file descriptor. It is diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 30475f340..a625f0e26 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -16,7 +16,6 @@ package fsutil import ( "math" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostMappable implements memmap.Mappable and platform.File over a diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index 4e100a402..adf5ec69c 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -15,13 +15,12 @@ package fsutil import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 798920d18..20a014402 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -17,7 +17,6 @@ package fsutil import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Lock order (compare the lock order model in mm/mm.go): diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 4a005c605..fd870e8e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 91263ebdc..245fe2ef1 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -16,7 +16,6 @@ package gofer import ( "errors" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 4e358a46a..edc796ce0 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -16,7 +16,6 @@ package gofer import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refs" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 23daeb528..2b581aa69 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index a6e4a09e3..873a1c52d 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -15,7 +15,6 @@ package host import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 107336a3e..c076d5bdd 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -16,7 +16,6 @@ package host import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 90331e3b2..753ef8cd6 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -15,8 +15,6 @@ package host import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 91e2fde2f..468043df0 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -15,8 +15,6 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go index 0f2a66a79..efd3c962b 100644 --- a/pkg/sentry/fs/inode_inotify.go +++ b/pkg/sentry/fs/inode_inotify.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Watches is the collection of inotify watches on an inode. diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index ba3e0233d..cc7dd1c92 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -16,7 +16,6 @@ package fs import ( "io" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go index 0aa0a5e9b..900cba3ca 100644 --- a/pkg/sentry/fs/inotify_watch.go +++ b/pkg/sentry/fs/inotify_watch.go @@ -15,10 +15,10 @@ package fs import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Watch represent a particular inotify watch created by inotify_add_watch. diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 8d62642e7..2c332a82a 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -44,6 +44,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 636484424..41b040818 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -52,9 +52,9 @@ package lock import ( "fmt" "math" - "sync" "syscall" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index ac0398bd9..db3dfd096 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -19,7 +19,6 @@ import ( "math" "path" "strings" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 25573e986..4cad55327 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -17,13 +17,12 @@ package fs import ( "fmt" "strings" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -199,7 +198,7 @@ type overlayEntry struct { upper *Inode // dirCacheMu protects dirCache. - dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"` + dirCacheMu sync.DowngradableRWMutex `state:"nosave"` // dirCache is cache of DentAttrs from upper and lower Inodes. dirCache *SortedDentryMap diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 75cbb0622..94d46ab1b 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/waiter", diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index fe7067be1..38b246dff 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/proc/device", "//pkg/sentry/kernel/time", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go index 5fe823000..f9af191d5 100644 --- a/pkg/sentry/fs/proc/seqfile/seqfile.go +++ b/pkg/sentry/fs/proc/seqfile/seqfile.go @@ -17,7 +17,6 @@ package seqfile import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index bd93f83fa..a37e1fa06 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -17,7 +17,6 @@ package proc import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 012cb3e44..3fb7b0633 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index 78e082b8e..dcbb8eb2e 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -17,7 +17,6 @@ package ramfs import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go index f10168125..64c6a6ae9 100644 --- a/pkg/sentry/fs/restore.go +++ b/pkg/sentry/fs/restore.go @@ -15,7 +15,7 @@ package fs import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // RestoreEnvironment is the restore environment for file systems. It consists diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 59ce400c2..3400b940c 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index f86dfaa36..f1c87fe41 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "fmt" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 95ad98cb0..f6f60d0cf 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 2f639c823..88aa66b24 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -19,7 +19,6 @@ import ( "fmt" "math" "strconv" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 7cc0eb409..894964260 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -16,13 +16,13 @@ package tty import ( "bytes" - "sync" "unicode/utf8" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index 231e4e6eb..8b5d4699a 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -15,13 +15,12 @@ package tty import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index bc90330bc..903874141 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/syscalls/linux", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 91802dc1e..8944171c8 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -15,8 +15,6 @@ package ext import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 616fc002a..9afb1a84c 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -17,13 +17,13 @@ package ext import ( "errors" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index aec33e00a..d11153c90 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -16,7 +16,6 @@ package ext import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 39c03ee9d..809178250 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -39,6 +39,7 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) @@ -56,6 +57,7 @@ go_test( "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "@com_github_google_go-cmp//cmp:go_default_library", ], diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 752e0f659..1d469a0db 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -16,7 +16,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index d69b299ae..bb12f39a2 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -53,7 +53,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -61,6 +60,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemType implements vfs.FilesystemType. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 4b6b95f5f..5c9d580e1 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "runtime" - "sync" "testing" "github.com/google/go-cmp/cmp" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index a5b285987..82f5c2f41 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index f51e247a7..f200e767d 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 7be6faa5b..701826f90 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -26,7 +26,6 @@ package tmpfs import ( "fmt" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 2706927ff..ac85ba0c8 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -35,7 +35,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -209,7 +209,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/state", "//pkg/state/statefile", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", @@ -241,6 +241,7 @@ go_test( "//pkg/sentry/time", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 244655b5c..920fe4329 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -15,11 +15,11 @@ package kernel import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" ) // +stateify savable diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 04c244447..1aa72fa47 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "Credentials", }, @@ -64,6 +64,7 @@ go_library( "//pkg/bits", "//pkg/log", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go index af28ccc65..9dd52c860 100644 --- a/pkg/sentry/kernel/auth/user_namespace.go +++ b/pkg/sentry/kernel/auth/user_namespace.go @@ -16,8 +16,8 @@ package auth import ( "math" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index 3361e8b7d..c47f6b6fc 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 9c0a4e1b4..430311cc0 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -18,7 +18,6 @@ package epoll import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index e65b961e8..c831fbab2 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 12f0d429b..687690679 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -18,7 +18,6 @@ package eventfd import ( "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 49d81b712..6b36bc63e 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 6b0bb0324..d32c3e90a 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -16,12 +16,11 @@ package fasync import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 11f613a11..cd1501f85 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -18,7 +18,6 @@ import ( "bytes" "fmt" "math" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // FDFlags define flags for an individual descriptor. diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2bcb6216a..eccb7d1e7 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -16,7 +16,6 @@ package kernel import ( "runtime" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/filetest" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index ded27d668..2448c1d99 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -16,10 +16,10 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // FSContext contains filesystem context. diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 75ec31761..50db443ce 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "bucket", }, @@ -42,6 +42,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/memmap", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) @@ -51,5 +52,8 @@ go_test( size = "small", srcs = ["futex_test.go"], embed = [":futex"], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 278cc8143..d1931c8f4 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -18,11 +18,10 @@ package futex import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index 65e5d1428..c23126ca5 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -17,13 +17,13 @@ package futex import ( "math" "runtime" - "sync" "sync/atomic" "syscall" "testing" "unsafe" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // testData implements the Target interface, and allows us to diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8653d2f63..c85e97fef 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -36,7 +36,6 @@ import ( "fmt" "io" "path/filepath" - "sync" "sync/atomic" "time" @@ -67,6 +66,7 @@ import ( uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index d7a7d1169..7f36252a9 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/usage", + "//pkg/sync", ], ) diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go index b0d98e7f0..200565bb8 100644 --- a/pkg/sentry/kernel/memevent/memory_events.go +++ b/pkg/sentry/kernel/memevent/memory_events.go @@ -17,7 +17,6 @@ package memevent import ( - "sync" "time" "gvisor.dev/gvisor/pkg/eventchannel" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" ) var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.") diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 9d34f6d4d..5eeaeff66 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 95bee2d37..1c0f34269 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -16,9 +16,9 @@ package pipe import ( "io" - "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" ) // buffer encapsulates a queueable byte buffer. diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 4a19ab7ce..716f589af 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -15,12 +15,11 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 1a1b38f83..e4fd7d420 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -17,12 +17,12 @@ package pipe import ( "fmt" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index ef9641e6a..8394eb78b 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -17,7 +17,6 @@ package pipe import ( "io" "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 6416e0dd8..bf7461cbb 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -15,13 +15,12 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index f4c00cd86..13a961594 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index de9617e9d..18299814e 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -17,7 +17,6 @@ package semaphore import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index cd48945e6..7321b22ed 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -24,6 +24,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 19034a21e..8ddef7eb8 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -35,7 +35,6 @@ package shm import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go index a16f3d57f..768fda220 100644 --- a/pkg/sentry/kernel/signal_handlers.go +++ b/pkg/sentry/kernel/signal_handlers.go @@ -15,10 +15,9 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sync" ) // SignalHandlers holds information about signal actions. diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 9f7e19b4d..89e4d84b1 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 4b08d7d72..28be4a939 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -16,8 +16,6 @@ package signalfd import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 2fdee0282..d2d01add4 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -16,13 +16,13 @@ package kernel import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // maxSyscallNum is the highest supported syscall number. diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 8227ecf1d..4607cde2f 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -17,7 +17,8 @@ package kernel import ( "fmt" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // syslog represents a sentry-global kernel log. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index d25a7903b..978d66da8 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -17,7 +17,6 @@ package kernel import ( gocontext "context" "runtime/trace" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -37,7 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -85,7 +84,7 @@ type Task struct { // // gosched is protected by goschedSeq. gosched is owned by the task // goroutine. - goschedSeq syncutil.SeqCount `state:"nosave"` + goschedSeq sync.SeqCount `state:"nosave"` gosched TaskGoroutineSchedInfo // yieldCount is the number of times the task goroutine has called diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index c0197a563..768e958d2 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -15,7 +15,6 @@ package kernel import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 8267929a6..bf2dabb6e 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -16,9 +16,9 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 31847e1df..4e4de0512 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index 107394183..706de83ef 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -19,10 +19,10 @@ package time import ( "fmt" "math" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 76417342a..dc99301de 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -16,7 +16,6 @@ package kernel import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sync" ) // Timekeeper manages all of the kernel clocks. diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go index 048de26dc..464d2306a 100644 --- a/pkg/sentry/kernel/tty.go +++ b/pkg/sentry/kernel/tty.go @@ -14,7 +14,7 @@ package kernel -import "sync" +import "gvisor.dev/gvisor/pkg/sync" // TTY defines the relationship between a thread group and its controlling // terminal. diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go index 0a563e715..8ccf04bd1 100644 --- a/pkg/sentry/kernel/uts_namespace.go +++ b/pkg/sentry/kernel/uts_namespace.go @@ -15,9 +15,8 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" ) // UTSNamespace represents a UTS namespace, a holder of two system identifiers: diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 156e67bf8..9fa841e8b 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", ], ) diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go index b6c22656b..31b9e9ff6 100644 --- a/pkg/sentry/limits/limits.go +++ b/pkg/sentry/limits/limits.go @@ -16,8 +16,9 @@ package limits import ( - "sync" "syscall" + + "gvisor.dev/gvisor/pkg/sync" ) // LimitType defines a type of resource limit. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 839931f67..83e248431 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -118,7 +118,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usage", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/buffer", ], diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 1b746d030..4b48866ad 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -15,8 +15,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 58a5c186d..fa86ebced 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -35,8 +35,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -44,7 +42,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryManager implements a virtual address space. @@ -82,7 +80,7 @@ type MemoryManager struct { users int32 // mappingMu is analogous to Linux's struct mm_struct::mmap_sem. - mappingMu syncutil.DowngradableRWMutex `state:"nosave"` + mappingMu sync.DowngradableRWMutex `state:"nosave"` // vmas stores virtual memory areas. Since vmas are stored by value, // clients should usually use vmaIterator.ValuePtr() instead of @@ -125,7 +123,7 @@ type MemoryManager struct { // activeMu is loosely analogous to Linux's struct // mm_struct::page_table_lock. - activeMu syncutil.DowngradableRWMutex `state:"nosave"` + activeMu sync.DowngradableRWMutex `state:"nosave"` // pmas stores platform mapping areas used to implement vmas. Since pmas // are stored by value, clients should usually use pmaIterator.ValuePtr() diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index f404107af..a9a2642c5 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -73,6 +73,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index f7f7298c4..c99e023d9 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -25,7 +25,6 @@ import ( "fmt" "math" "os" - "sync" "sync/atomic" "syscall" "time" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index b6d008dbe..85e882df9 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -10,6 +10,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go index a4651f500..57be41647 100644 --- a/pkg/sentry/platform/interrupt/interrupt.go +++ b/pkg/sentry/platform/interrupt/interrupt.go @@ -17,7 +17,8 @@ package interrupt import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Receiver receives interrupt notifications from a Forwarder. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index f3afd98da..6a358d1d4 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -55,6 +55,7 @@ go_library( "//pkg/sentry/platform/safecopy", "//pkg/sentry/time", "//pkg/sentry/usermem", + "//pkg/sync", ], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index ea8b9632e..a25f3c449 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -15,13 +15,13 @@ package kvm import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // dirtySet tracks vCPUs for invalidation. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index e5fac0d6a..2f02c03cf 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -17,8 +17,6 @@ package kvm import ( - "unsafe" - "gvisor.dev/gvisor/pkg/sentry/arch" ) diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index f2c2c059e..a7850faed 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -18,13 +18,13 @@ package kvm import ( "fmt" "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // KVM represents a lightweight VM context. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 7d02ebf19..e6d912168 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,7 +17,6 @@ package kvm import ( "fmt" "runtime" - "sync" "sync/atomic" "syscall" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // machine contains state associated with the VM as a whole. diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 0df8cfa0f..cd13390c3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/safecopy", "//pkg/sentry/usermem", + "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 7b120a15d..bb0e03880 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -46,13 +46,13 @@ package ptrace import ( "os" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 20244fd95..15dc46a5b 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -18,7 +18,6 @@ import ( "fmt" "os" "runtime" - "sync" "syscall" "golang.org/x/sys/unix" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Linux kernel errnos which "should never be seen by user programs", but will diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index 2e6fbe488..245b20722 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -18,7 +18,6 @@ package ptrace import ( - "sync" "sync/atomic" "syscall" "unsafe" @@ -26,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/hostcpu" + "gvisor.dev/gvisor/pkg/sync" ) // maskPool contains reusable CPU masks for setting affinity. Unfortunately, diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 3f094c2a7..86fd5ed58 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -17,7 +17,7 @@ package ring0 import ( "syscall" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) // Kernel is a global kernel object. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 10dbd381f..9dae0dccb 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index dc0eeec01..a850ce6cf 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index e2e15ba5c..387a7f6c3 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -96,7 +96,10 @@ go_library( "//pkg/sentry/platform/kvm:__subpackages__", "//pkg/sentry/platform/ring0:__subpackages__", ], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go index 0f029f25d..e199bae18 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go @@ -17,7 +17,7 @@ package pagetables import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // limitPCID is the number of valid PCIDs. diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 136821963..103933144 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 463544c1a..2d9f4ba9b 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["port.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go index e9d3275b1..2cd3afc22 100644 --- a/pkg/sentry/socket/netlink/port/port.go +++ b/pkg/sentry/socket/netlink/port/port.go @@ -24,7 +24,8 @@ import ( "fmt" "math" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // maxPorts is a sanity limit on the maximum number of ports to allocate per diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d2e3644a6..cea56f4ed 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -17,7 +17,6 @@ package netlink import ( "math" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e414d8055..f78784569 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 764f11a6b..0affb8071 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,7 +29,6 @@ import ( "io" "math" "reflect" - "sync" "syscall" "time" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD index 23eadcb1b..b2677c659 100644 --- a/pkg/sentry/socket/rpcinet/conn/BUILD +++ b/pkg/sentry/socket/rpcinet/conn/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/binary", "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", + "//pkg/sync", "//pkg/syserr", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go index 356adad99..02f39c767 100644 --- a/pkg/sentry/socket/rpcinet/conn/conn.go +++ b/pkg/sentry/socket/rpcinet/conn/conn.go @@ -17,12 +17,12 @@ package conn import ( "fmt" - "sync" "sync/atomic" "syscall" "github.com/golang/protobuf/proto" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/unet" diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD index a3585e10d..a5954f22b 100644 --- a/pkg/sentry/socket/rpcinet/notifier/BUILD +++ b/pkg/sentry/socket/rpcinet/notifier/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", "//pkg/sentry/socket/rpcinet/conn", + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go index 7efe4301f..82b75d6dd 100644 --- a/pkg/sentry/socket/rpcinet/notifier/notifier.go +++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go @@ -17,12 +17,12 @@ package notifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 788ad70d2..d7ba95dff 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/ilist", "//pkg/refs", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index dea11e253..9e6fbc111 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -15,10 +15,9 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e27b1c714..5dcd3d95e 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,9 +15,8 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37c7ac3c1..fcc0da332 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,11 +16,11 @@ package transport import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index a76975cee..aa05e208a 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -91,6 +91,7 @@ go_library( "//pkg/sentry/syscalls", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 1d9018c96..60469549d 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -16,13 +16,13 @@ package linux import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 18e212dff..3cde3a0be 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "Parameters", }, @@ -36,7 +36,7 @@ go_library( deps = [ "//pkg/log", "//pkg/metric", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index 318503277..f9a93115d 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -17,11 +17,11 @@ package time import ( - "sync" "time" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index c32fe3241..5518ac3d0 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -18,5 +18,6 @@ go_library( deps = [ "//pkg/bits", "//pkg/memutil", + "//pkg/sync", ], ) diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index d6ef644d8..538c645eb 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -17,12 +17,12 @@ package usage import ( "fmt" "os" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryKind represents a type of memory used by the application. diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 4c6aa04a1..35c7be259 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -34,7 +34,7 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -54,6 +54,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 1bc9c4a38..486a76475 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -16,9 +16,9 @@ package vfs import ( "fmt" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 66eb57bc2..c00b3c84b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -17,13 +17,13 @@ package vfs import ( "bytes" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index adff0b94b..3b933468d 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -17,8 +17,9 @@ package vfs import ( "fmt" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestMountTableLookupEmpty(t *testing.T) { diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index ab13fa461..bd90d36c4 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -26,7 +26,7 @@ import ( "sync/atomic" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // mountKey represents the location at which a Mount is mounted. It is @@ -75,7 +75,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq syncutil.SeqCount + seq sync.SeqCount seed uint32 // for hashing keys // size holds both length (number of elements) and capacity (number of diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index 8e155654f..cf80df90e 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -15,10 +15,9 @@ package vfs import ( - "sync" - "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index f0641d314..8a0b382f6 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -16,11 +16,11 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index ea2db7031..1f21b0b31 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -29,12 +29,12 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD index 4d8435265..28f21f13d 100644 --- a/pkg/sentry/watchdog/BUILD +++ b/pkg/sentry/watchdog/BUILD @@ -13,5 +13,6 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", + "//pkg/sync", ], ) diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 5e4611333..bfb2fac26 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -32,7 +32,6 @@ package watchdog import ( "bytes" "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -40,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" ) // Opts configures the watchdog. diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD new file mode 100644 index 000000000..e8cd16b8f --- /dev/null +++ b/pkg/sync/BUILD @@ -0,0 +1,53 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +exports_files(["LICENSE"]) + +go_template( + name = "generic_atomicptr", + srcs = ["atomicptr_unsafe.go"], + types = [ + "Value", + ], +) + +go_template( + name = "generic_seqatomic", + srcs = ["seqatomic_unsafe.go"], + types = [ + "Value", + ], + deps = [ + ":sync", + ], +) + +go_library( + name = "sync", + srcs = [ + "aliases.go", + "downgradable_rwmutex_unsafe.go", + "memmove_unsafe.go", + "norace_unsafe.go", + "race_unsafe.go", + "seqcount.go", + "syncutil.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sync", +) + +go_test( + name = "sync_test", + size = "small", + srcs = [ + "downgradable_rwmutex_test.go", + "seqcount_test.go", + ], + embed = [":sync"], +) diff --git a/pkg/sync/LICENSE b/pkg/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/pkg/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sync/README.md b/pkg/sync/README.md new file mode 100644 index 000000000..2183c4e20 --- /dev/null +++ b/pkg/sync/README.md @@ -0,0 +1,5 @@ +# Syncutil + +This package provides additional synchronization primitives not provided by the +Go stdlib 'sync' package. It is partially derived from the upstream 'sync' +package from go1.10. diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go new file mode 100644 index 000000000..20c7ca041 --- /dev/null +++ b/pkg/sync/aliases.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "sync" +) + +// Aliases of standard library types. +type ( + // Mutex is an alias of sync.Mutex. + Mutex = sync.Mutex + + // RWMutex is an alias of sync.RWMutex. + RWMutex = sync.RWMutex + + // Cond is an alias of sync.Cond. + Cond = sync.Cond + + // Locker is an alias of sync.Locker. + Locker = sync.Locker + + // Once is an alias of sync.Once. + Once = sync.Once + + // Pool is an alias of sync.Pool. + Pool = sync.Pool + + // WaitGroup is an alias of sync.WaitGroup. + WaitGroup = sync.WaitGroup + + // Map is an alias of sync.Map. + Map = sync.Map +) diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go new file mode 100644 index 000000000..525c4beed --- /dev/null +++ b/pkg/sync/atomicptr_unsafe.go @@ -0,0 +1,47 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "sync/atomic" + "unsafe" +) + +// Value is a required type parameter. +type Value struct{} + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtr struct { + ptr unsafe.Pointer `state:".(*Value)"` +} + +func (p *AtomicPtr) savePtr() *Value { + return p.Load() +} + +func (p *AtomicPtr) loadPtr(v *Value) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtr) Load() *Value { + return (*Value)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtr) Store(x *Value) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD new file mode 100644 index 000000000..418eda29c --- /dev/null +++ b/pkg/sync/atomicptrtest/BUILD @@ -0,0 +1,29 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "atomicptr_int", + out = "atomicptr_int_unsafe.go", + package = "atomicptr", + suffix = "Int", + template = "//pkg/sync:generic_atomicptr", + types = { + "Value": "int", + }, +) + +go_library( + name = "atomicptr", + srcs = ["atomicptr_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr", +) + +go_test( + name = "atomicptr_test", + size = "small", + srcs = ["atomicptr_test.go"], + embed = [":atomicptr"], +) diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go new file mode 100644 index 000000000..8fdc5112e --- /dev/null +++ b/pkg/sync/atomicptrtest/atomicptr_test.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package atomicptr + +import ( + "testing" +) + +func newInt(val int) *int { + return &val +} + +func TestAtomicPtr(t *testing.T) { + var p AtomicPtrInt + if got := p.Load(); got != nil { + t.Errorf("initial value is %p (%v), wanted nil", got, got) + } + want := newInt(42) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } + want = newInt(100) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } +} diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/downgradable_rwmutex_test.go new file mode 100644 index 000000000..f04496bc5 --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_test.go @@ -0,0 +1,150 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// GOMAXPROCS=10 go test + +// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the +// addition of downgradingWriter and the renaming of num_iterations to +// numIterations to shut up Golint. + +package sync + +import ( + "fmt" + "runtime" + "sync/atomic" + "testing" +) + +func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { + m.RLock() + clocked <- true + <-cunlock + m.RUnlock() + cdone <- true +} + +func doTestParallelReaders(numReaders, gomaxprocs int) { + runtime.GOMAXPROCS(gomaxprocs) + var m DowngradableRWMutex + clocked := make(chan bool) + cunlock := make(chan bool) + cdone := make(chan bool) + for i := 0; i < numReaders; i++ { + go parallelReader(&m, clocked, cunlock, cdone) + } + // Wait for all parallel RLock()s to succeed. + for i := 0; i < numReaders; i++ { + <-clocked + } + for i := 0; i < numReaders; i++ { + cunlock <- true + } + // Wait for the goroutines to finish. + for i := 0; i < numReaders; i++ { + <-cdone + } +} + +func TestParallelReaders(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + doTestParallelReaders(1, 4) + doTestParallelReaders(3, 4) + doTestParallelReaders(4, 2) +} + +func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.RLock() + n := atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.Unlock() + } + cdone <- true +} + +func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.DowngradeLock() + n = atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + n = atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { + runtime.GOMAXPROCS(gomaxprocs) + // Number of active readers + 10000 * number of active writers. + var activity int32 + var rwm DowngradableRWMutex + cdone := make(chan bool) + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + var i int + for i = 0; i < numReaders/2; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + for ; i < numReaders; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + // Wait for the 4 writers and all readers to finish. + for i := 0; i < 4+numReaders; i++ { + <-cdone + } +} + +func TestDowngradableRWMutex(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + n := 1000 + if testing.Short() { + n = 5 + } + HammerDowngradableRWMutex(1, 1, n) + HammerDowngradableRWMutex(1, 3, n) + HammerDowngradableRWMutex(1, 10, n) + HammerDowngradableRWMutex(4, 1, n) + HammerDowngradableRWMutex(4, 3, n) + HammerDowngradableRWMutex(4, 10, n) + HammerDowngradableRWMutex(10, 1, n) + HammerDowngradableRWMutex(10, 3, n) + HammerDowngradableRWMutex(10, 10, n) + HammerDowngradableRWMutex(10, 5, n) +} diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/downgradable_rwmutex_unsafe.go new file mode 100644 index 000000000..9bb55cd3a --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_unsafe.go @@ -0,0 +1,146 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +// This is mostly copied from the standard library's sync/rwmutex.go. +// +// Happens-before relationships indicated to the race detector: +// - Unlock -> Lock (via writerSem) +// - Unlock -> RLock (via readerSem) +// - RUnlock -> Lock (via writerSem) +// - DowngradeLock -> RLock (via readerSem) + +package sync + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +//go:linkname runtimeSemacquire sync.runtime_Semacquire +func runtimeSemacquire(s *uint32) + +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + +// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock +// method. +type DowngradableRWMutex struct { + w sync.Mutex // held if there are pending writers + writerSem uint32 // semaphore for writers to wait for completing readers + readerSem uint32 // semaphore for readers to wait for completing writers + readerCount int32 // number of pending readers + readerWait int32 // number of departing readers +} + +const rwmutexMaxReaders = 1 << 30 + +// RLock locks rw for reading. +func (rw *DowngradableRWMutex) RLock() { + if RaceEnabled { + RaceDisable() + } + if atomic.AddInt32(&rw.readerCount, 1) < 0 { + // A writer is pending, wait for it. + runtimeSemacquire(&rw.readerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.readerSem)) + } +} + +// RUnlock undoes a single RLock call. +func (rw *DowngradableRWMutex) RUnlock() { + if RaceEnabled { + RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) + RaceDisable() + } + if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { + if r+1 == 0 || r+1 == -rwmutexMaxReaders { + panic("RUnlock of unlocked DowngradableRWMutex") + } + // A writer is pending. + if atomic.AddInt32(&rw.readerWait, -1) == 0 { + // The last reader unblocks the writer. + runtimeSemrelease(&rw.writerSem, false, 0) + } + } + if RaceEnabled { + RaceEnable() + } +} + +// Lock locks rw for writing. +func (rw *DowngradableRWMutex) Lock() { + if RaceEnabled { + RaceDisable() + } + // First, resolve competition with other writers. + rw.w.Lock() + // Announce to readers there is a pending writer. + r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders + // Wait for active readers. + if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { + runtimeSemacquire(&rw.writerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.writerSem)) + } +} + +// Unlock unlocks rw for writing. +func (rw *DowngradableRWMutex) Unlock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.writerSem)) + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) + if r >= rwmutexMaxReaders { + panic("Unlock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. + for i := 0; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} + +// DowngradeLock atomically unlocks rw for writing and locks it for reading. +func (rw *DowngradableRWMutex) DowngradeLock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer and one additional reader. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) + if r >= rwmutexMaxReaders+1 { + panic("DowngradeLock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. Note that this loop starts as 1 since r + // includes this goroutine. + for i := 1; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed to rw.w.Lock(). Note that they will still + // block on rw.writerSem since at least this reader exists, such that + // DowngradeLock() is atomic with the previous write lock. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go new file mode 100644 index 000000000..ad4a3a37e --- /dev/null +++ b/pkg/sync/memmove_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +package sync + +import ( + "unsafe" +) + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't +// define it because go_generics can't update the go:linkname annotation. +// Furthermore, go:linkname silently doesn't work if the local name is exported +// (this is of course undocumented), which is why this indirection is +// necessary. +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go new file mode 100644 index 000000000..006055dd6 --- /dev/null +++ b/pkg/sync/norace_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !race + +package sync + +import ( + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = false + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { +} diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go new file mode 100644 index 000000000..31d8fa9a6 --- /dev/null +++ b/pkg/sync/race_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package sync + +import ( + "runtime" + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = true + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { + runtime.RaceDisable() +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { + runtime.RaceEnable() +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { + runtime.RaceAcquire(addr) +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { + runtime.RaceRelease(addr) +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { + runtime.RaceReleaseMerge(addr) +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go new file mode 100644 index 000000000..eda6fb131 --- /dev/null +++ b/pkg/sync/seqatomic_unsafe.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/pkg/sync" +) + +// Value is a required type parameter. +// +// Value must not contain any pointers, including interface objects, function +// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs +// containing any of the above. An init() function will panic if this property +// does not hold. +type Value struct{} + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Value + for { + epoch := sc.BeginRead() + if sync.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, + // so it can't help us here. Instead, call runtime.memmove + // directly, which is not instrumented by the race detector. + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + // This is ~40% faster for short reads than going through memmove. + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { + var val Value + if sync.RaceEnabled { + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func init() { + var val Value + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD new file mode 100644 index 000000000..eba21518d --- /dev/null +++ b/pkg/sync/seqatomictest/BUILD @@ -0,0 +1,33 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "seqatomic_int", + out = "seqatomic_int_unsafe.go", + package = "seqatomic", + suffix = "Int", + template = "//pkg/sync:generic_seqatomic", + types = { + "Value": "int", + }, +) + +go_library( + name = "seqatomic", + srcs = ["seqatomic_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic", + deps = [ + "//pkg/sync", + ], +) + +go_test( + name = "seqatomic_test", + size = "small", + srcs = ["seqatomic_test.go"], + embed = [":seqatomic"], + deps = ["//pkg/sync"], +) diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go new file mode 100644 index 000000000..2c4568b07 --- /dev/null +++ b/pkg/sync/seqatomictest/seqatomic_test.go @@ -0,0 +1,132 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seqatomic + +import ( + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +func TestSeqAtomicLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicTryLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + epoch := seq.BeginRead() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } +} + +func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } + seq.EndWrite() +} + +func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } +} + +func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := SeqAtomicLoadInt(&seq, &data); got != want { + b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } + } + }) +} + +func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + epoch := seq.BeginRead() + for pb.Next() { + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } + } + }) +} + +// For comparison: +func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { + var a atomic.Value + const want = 42 + a.Store(int(want)) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := a.Load().(int); got != want { + b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) + } + } + }) +} diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go new file mode 100644 index 000000000..a1e895352 --- /dev/null +++ b/pkg/sync/seqcount.go @@ -0,0 +1,149 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "fmt" + "reflect" + "runtime" + "sync/atomic" +) + +// SeqCount is a synchronization primitive for optimistic reader/writer +// synchronization in cases where readers can work with stale data and +// therefore do not need to block writers. +// +// Compared to sync/atomic.Value: +// +// - Mutation of SeqCount-protected data does not require memory allocation, +// whereas atomic.Value generally does. This is a significant advantage when +// writes are common. +// +// - Atomic reads of SeqCount-protected data require copying. This is a +// disadvantage when atomic reads are common. +// +// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other +// operations to be made atomic with reads of SeqCount-protected data. +// +// - SeqCount may be less flexible: as of this writing, SeqCount-protected data +// cannot include pointers. +// +// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected +// data require instantiating function templates using go_generics (see +// seqatomic.go). +type SeqCount struct { + // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd + // if a writer critical section is active, and a read from data protected + // by this SeqCount is atomic iff epoch is the same even value before and + // after the read. + epoch uint32 +} + +// SeqCountEpoch tracks writer critical sections in a SeqCount. +type SeqCountEpoch struct { + val uint32 +} + +// We assume that: +// +// - All functions in sync/atomic that perform a memory read are at least a +// read fence: memory reads before calls to such functions cannot be reordered +// after the call, and memory reads after calls to such functions cannot be +// reordered before the call, even if those reads do not use sync/atomic. +// +// - All functions in sync/atomic that perform a memory write are at least a +// write fence: memory writes before calls to such functions cannot be +// reordered after the call, and memory writes after calls to such functions +// cannot be reordered before the call, even if those writes do not use +// sync/atomic. +// +// As of this writing, the Go memory model completely fails to describe +// sync/atomic, but these properties are implied by +// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. + +// BeginRead indicates the beginning of a reader critical section. Reader +// critical sections DO NOT BLOCK writer critical sections, so operations in a +// reader critical section MAY RACE with writer critical sections. Races are +// detected by ReadOk at the end of the reader critical section. Thus, the +// low-level structure of readers is generally: +// +// for { +// epoch := seq.BeginRead() +// // do something idempotent with seq-protected data +// if seq.ReadOk(epoch) { +// break +// } +// } +// +// However, since reader critical sections may race with writer critical +// sections, the Go race detector will (accurately) flag data races in readers +// using this pattern. Most users of SeqCount will need to use the +// SeqAtomicLoad function template in seqatomic.go. +func (s *SeqCount) BeginRead() SeqCountEpoch { + epoch := atomic.LoadUint32(&s.epoch) + for epoch&1 != 0 { + runtime.Gosched() + epoch = atomic.LoadUint32(&s.epoch) + } + return SeqCountEpoch{epoch} +} + +// ReadOk returns true if the reader critical section initiated by a previous +// call to BeginRead() that returned epoch did not race with any writer critical +// sections. +// +// ReadOk may be called any number of times during a reader critical section. +// Reader critical sections do not need to be explicitly terminated; the last +// call to ReadOk is implicitly the end of the reader critical section. +func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { + return atomic.LoadUint32(&s.epoch) == epoch.val +} + +// BeginWrite indicates the beginning of a writer critical section. +// +// SeqCount does not support concurrent writer critical sections; clients with +// concurrent writers must synchronize them using e.g. sync.Mutex. +func (s *SeqCount) BeginWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { + panic("SeqCount.BeginWrite during writer critical section") + } +} + +// EndWrite ends the effect of a preceding BeginWrite. +func (s *SeqCount) EndWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { + panic("SeqCount.EndWrite outside writer critical section") + } +} + +// PointersInType returns a list of pointers reachable from values named +// valName of the given type. +// +// PointersInType is not exhaustive, but it is guaranteed that if typ contains +// at least one pointer, then PointersInTypeOf returns a non-empty list. +func PointersInType(typ reflect.Type, valName string) []string { + switch kind := typ.Kind(); kind { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return nil + + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: + return []string{valName} + + case reflect.Array: + return PointersInType(typ.Elem(), valName+"[]") + + case reflect.Struct: + var ptrs []string + for i, n := 0, typ.NumField(); i < n; i++ { + field := typ.Field(i) + ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) + } + return ptrs + + default: + return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} + } +} diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go new file mode 100644 index 000000000..6eb7b4b59 --- /dev/null +++ b/pkg/sync/seqcount_test.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "reflect" + "testing" + "time" +) + +func TestSeqCountWriteUncontended(t *testing.T) { + var seq SeqCount + seq.BeginWrite() + seq.EndWrite() +} + +func TestSeqCountReadUncontended(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadAfterWrite(t *testing.T) { + var seq SeqCount + var data int32 + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadDuringWrite(t *testing.T) { + var seq SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountReadOkAfterWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } +} + +func TestSeqCountReadOkDuringWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } + seq.EndWrite() +} + +func BenchmarkSeqCountWriteUncontended(b *testing.B) { + var seq SeqCount + for i := 0; i < b.N; i++ { + seq.BeginWrite() + seq.EndWrite() + } +} + +func BenchmarkSeqCountReadUncontended(b *testing.B) { + var seq SeqCount + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + b.Fatalf("ReadOk: got false, wanted true") + } + } + }) +} + +func TestPointersInType(t *testing.T) { + for _, test := range []struct { + name string // used for both test and value name + val interface{} + ptrs []string + }{ + { + name: "EmptyStruct", + val: struct{}{}, + }, + { + name: "Int", + val: int(0), + }, + { + name: "MixedStruct", + val: struct { + b bool + I int + ExportedPtr *struct{} + unexportedPtr *struct{} + arr [2]int + ptrArr [2]*int + nestedStruct struct { + nestedNonptr int + nestedPtr *int + } + structArr [1]struct { + nonptr int + ptr *int + } + }{}, + ptrs: []string{ + "MixedStruct.ExportedPtr", + "MixedStruct.unexportedPtr", + "MixedStruct.ptrArr[]", + "MixedStruct.nestedStruct.nestedPtr", + "MixedStruct.structArr[].ptr", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + typ := reflect.TypeOf(test.val) + ptrs := PointersInType(typ, test.name) + t.Logf("Found pointers: %v", ptrs) + if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { + t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) + } + }) + } +} diff --git a/pkg/sync/syncutil.go b/pkg/sync/syncutil.go new file mode 100644 index 000000000..b16cf5333 --- /dev/null +++ b/pkg/sync/syncutil.go @@ -0,0 +1,7 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sync provides synchronization primitives. +package sync diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD deleted file mode 100644 index cb1f41628..000000000 --- a/pkg/syncutil/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -go_template( - name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - ], -) - -go_library( - name = "syncutil", - srcs = [ - "downgradable_rwmutex_unsafe.go", - "memmove_unsafe.go", - "norace_unsafe.go", - "race_unsafe.go", - "seqcount.go", - "syncutil.go", - ], - importpath = "gvisor.dev/gvisor/pkg/syncutil", -) - -go_test( - name = "syncutil_test", - size = "small", - srcs = [ - "downgradable_rwmutex_test.go", - "seqcount_test.go", - ], - embed = [":syncutil"], -) diff --git a/pkg/syncutil/LICENSE b/pkg/syncutil/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/syncutil/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/syncutil/README.md b/pkg/syncutil/README.md deleted file mode 100644 index 2183c4e20..000000000 --- a/pkg/syncutil/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Syncutil - -This package provides additional synchronization primitives not provided by the -Go stdlib 'sync' package. It is partially derived from the upstream 'sync' -package from go1.10. diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go deleted file mode 100644 index 525c4beed..000000000 --- a/pkg/syncutil/atomicptr_unsafe.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "sync/atomic" - "unsafe" -) - -// Value is a required type parameter. -type Value struct{} - -// An AtomicPtr is a pointer to a value of type Value that can be atomically -// loaded and stored. The zero value of an AtomicPtr represents nil. -// -// Note that copying AtomicPtr by value performs a non-atomic read of the -// stored pointer, which is unsafe if Store() can be called concurrently; in -// this case, do `dst.Store(src.Load())` instead. -// -// +stateify savable -type AtomicPtr struct { - ptr unsafe.Pointer `state:".(*Value)"` -} - -func (p *AtomicPtr) savePtr() *Value { - return p.Load() -} - -func (p *AtomicPtr) loadPtr(v *Value) { - p.Store(v) -} - -// Load returns the value set by the most recent Store. It returns nil if there -// has been no previous call to Store. -func (p *AtomicPtr) Load() *Value { - return (*Value)(atomic.LoadPointer(&p.ptr)) -} - -// Store sets the value returned by Load to x. -func (p *AtomicPtr) Store(x *Value) { - atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) -} diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD deleted file mode 100644 index 63f411a90..000000000 --- a/pkg/syncutil/atomicptrtest/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_int", - out = "atomicptr_int_unsafe.go", - package = "atomicptr", - suffix = "Int", - template = "//pkg/syncutil:generic_atomicptr", - types = { - "Value": "int", - }, -) - -go_library( - name = "atomicptr", - srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr", -) - -go_test( - name = "atomicptr_test", - size = "small", - srcs = ["atomicptr_test.go"], - embed = [":atomicptr"], -) diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go deleted file mode 100644 index 8fdc5112e..000000000 --- a/pkg/syncutil/atomicptrtest/atomicptr_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package atomicptr - -import ( - "testing" -) - -func newInt(val int) *int { - return &val -} - -func TestAtomicPtr(t *testing.T) { - var p AtomicPtrInt - if got := p.Load(); got != nil { - t.Errorf("initial value is %p (%v), wanted nil", got, got) - } - want := newInt(42) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } - want = newInt(100) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } -} diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go deleted file mode 100644 index ffaf7ecc7..000000000 --- a/pkg/syncutil/downgradable_rwmutex_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// GOMAXPROCS=10 go test - -// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the -// addition of downgradingWriter and the renaming of num_iterations to -// numIterations to shut up Golint. - -package syncutil - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" -) - -func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { - m.RLock() - clocked <- true - <-cunlock - m.RUnlock() - cdone <- true -} - -func doTestParallelReaders(numReaders, gomaxprocs int) { - runtime.GOMAXPROCS(gomaxprocs) - var m DowngradableRWMutex - clocked := make(chan bool) - cunlock := make(chan bool) - cdone := make(chan bool) - for i := 0; i < numReaders; i++ { - go parallelReader(&m, clocked, cunlock, cdone) - } - // Wait for all parallel RLock()s to succeed. - for i := 0; i < numReaders; i++ { - <-clocked - } - for i := 0; i < numReaders; i++ { - cunlock <- true - } - // Wait for the goroutines to finish. - for i := 0; i < numReaders; i++ { - <-cdone - } -} - -func TestParallelReaders(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - doTestParallelReaders(1, 4) - doTestParallelReaders(3, 4) - doTestParallelReaders(4, 2) -} - -func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.RLock() - n := atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.Unlock() - } - cdone <- true -} - -func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.DowngradeLock() - n = atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - n = atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { - runtime.GOMAXPROCS(gomaxprocs) - // Number of active readers + 10000 * number of active writers. - var activity int32 - var rwm DowngradableRWMutex - cdone := make(chan bool) - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - var i int - for i = 0; i < numReaders/2; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - for ; i < numReaders; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - // Wait for the 4 writers and all readers to finish. - for i := 0; i < 4+numReaders; i++ { - <-cdone - } -} - -func TestDowngradableRWMutex(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - n := 1000 - if testing.Short() { - n = 5 - } - HammerDowngradableRWMutex(1, 1, n) - HammerDowngradableRWMutex(1, 3, n) - HammerDowngradableRWMutex(1, 10, n) - HammerDowngradableRWMutex(4, 1, n) - HammerDowngradableRWMutex(4, 3, n) - HammerDowngradableRWMutex(4, 10, n) - HammerDowngradableRWMutex(10, 1, n) - HammerDowngradableRWMutex(10, 3, n) - HammerDowngradableRWMutex(10, 10, n) - HammerDowngradableRWMutex(10, 5, n) -} diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go deleted file mode 100644 index 51e11555d..000000000 --- a/pkg/syncutil/downgradable_rwmutex_unsafe.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.13 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -// This is mostly copied from the standard library's sync/rwmutex.go. -// -// Happens-before relationships indicated to the race detector: -// - Unlock -> Lock (via writerSem) -// - Unlock -> RLock (via readerSem) -// - RUnlock -> Lock (via writerSem) -// - DowngradeLock -> RLock (via readerSem) - -package syncutil - -import ( - "sync" - "sync/atomic" - "unsafe" -) - -//go:linkname runtimeSemacquire sync.runtime_Semacquire -func runtimeSemacquire(s *uint32) - -//go:linkname runtimeSemrelease sync.runtime_Semrelease -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) - -// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock -// method. -type DowngradableRWMutex struct { - w sync.Mutex // held if there are pending writers - writerSem uint32 // semaphore for writers to wait for completing readers - readerSem uint32 // semaphore for readers to wait for completing writers - readerCount int32 // number of pending readers - readerWait int32 // number of departing readers -} - -const rwmutexMaxReaders = 1 << 30 - -// RLock locks rw for reading. -func (rw *DowngradableRWMutex) RLock() { - if RaceEnabled { - RaceDisable() - } - if atomic.AddInt32(&rw.readerCount, 1) < 0 { - // A writer is pending, wait for it. - runtimeSemacquire(&rw.readerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.readerSem)) - } -} - -// RUnlock undoes a single RLock call. -func (rw *DowngradableRWMutex) RUnlock() { - if RaceEnabled { - RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) - RaceDisable() - } - if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { - if r+1 == 0 || r+1 == -rwmutexMaxReaders { - panic("RUnlock of unlocked DowngradableRWMutex") - } - // A writer is pending. - if atomic.AddInt32(&rw.readerWait, -1) == 0 { - // The last reader unblocks the writer. - runtimeSemrelease(&rw.writerSem, false, 0) - } - } - if RaceEnabled { - RaceEnable() - } -} - -// Lock locks rw for writing. -func (rw *DowngradableRWMutex) Lock() { - if RaceEnabled { - RaceDisable() - } - // First, resolve competition with other writers. - rw.w.Lock() - // Announce to readers there is a pending writer. - r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders - // Wait for active readers. - if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { - runtimeSemacquire(&rw.writerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.writerSem)) - } -} - -// Unlock unlocks rw for writing. -func (rw *DowngradableRWMutex) Unlock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.writerSem)) - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) - if r >= rwmutexMaxReaders { - panic("Unlock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. - for i := 0; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} - -// DowngradeLock atomically unlocks rw for writing and locks it for reading. -func (rw *DowngradableRWMutex) DowngradeLock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer and one additional reader. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) - if r >= rwmutexMaxReaders+1 { - panic("DowngradeLock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. Note that this loop starts as 1 since r - // includes this goroutine. - for i := 1; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed to rw.w.Lock(). Note that they will still - // block on rw.writerSem since at least this reader exists, such that - // DowngradeLock() is atomic with the previous write lock. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go deleted file mode 100644 index 348675baa..000000000 --- a/pkg/syncutil/memmove_unsafe.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.12 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -package syncutil - -import ( - "unsafe" -) - -//go:linkname memmove runtime.memmove -//go:noescape -func memmove(to, from unsafe.Pointer, n uintptr) - -// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't -// define it because go_generics can't update the go:linkname annotation. -// Furthermore, go:linkname silently doesn't work if the local name is exported -// (this is of course undocumented), which is why this indirection is -// necessary. -func Memmove(to, from unsafe.Pointer, n uintptr) { - memmove(to, from, n) -} diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go deleted file mode 100644 index 0a0a9deda..000000000 --- a/pkg/syncutil/norace_unsafe.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !race - -package syncutil - -import ( - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = false - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { -} diff --git a/pkg/syncutil/race_unsafe.go b/pkg/syncutil/race_unsafe.go deleted file mode 100644 index 206067ec1..000000000 --- a/pkg/syncutil/race_unsafe.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build race - -package syncutil - -import ( - "runtime" - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = true - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { - runtime.RaceDisable() -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { - runtime.RaceEnable() -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { - runtime.RaceAcquire(addr) -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { - runtime.RaceRelease(addr) -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { - runtime.RaceReleaseMerge(addr) -} diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go deleted file mode 100644 index cb6d2eb22..000000000 --- a/pkg/syncutil/seqatomic_unsafe.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "fmt" - "reflect" - "strings" - "unsafe" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -// Value is a required type parameter. -// -// Value must not contain any pointers, including interface objects, function -// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs -// containing any of the above. An init() function will panic if this property -// does not hold. -type Value struct{} - -// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in sc. -func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value { - // This function doesn't use SeqAtomicTryLoad because doing so is - // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value - for { - epoch := sc.BeginRead() - if syncutil.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - if sc.ReadOk(epoch) { - break - } - } - return val -} - -// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read -// would race with a writer critical section, SeqAtomicTryLoad returns -// (unspecified, false). -func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value - if syncutil.RaceEnabled { - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - val = *ptr - } - return val, sc.ReadOk(epoch) -} - -func init() { - var val Value - typ := reflect.TypeOf(val) - name := typ.Name() - if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 { - panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) - } -} diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD deleted file mode 100644 index ba18f3238..000000000 --- a/pkg/syncutil/seqatomictest/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "seqatomic_int", - out = "seqatomic_int_unsafe.go", - package = "seqatomic", - suffix = "Int", - template = "//pkg/syncutil:generic_seqatomic", - types = { - "Value": "int", - }, -) - -go_library( - name = "seqatomic", - srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic", - deps = [ - "//pkg/syncutil", - ], -) - -go_test( - name = "seqatomic_test", - size = "small", - srcs = ["seqatomic_test.go"], - embed = [":seqatomic"], - deps = [ - "//pkg/syncutil", - ], -) diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go deleted file mode 100644 index b0db44999..000000000 --- a/pkg/syncutil/seqatomictest/seqatomic_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqatomic - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -func TestSeqAtomicLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicTryLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - epoch := seq.BeginRead() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } -} - -func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } - seq.EndWrite() -} - -func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } -} - -func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := SeqAtomicLoadInt(&seq, &data); got != want { - b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } - } - }) -} - -func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - epoch := seq.BeginRead() - for pb.Next() { - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } - } - }) -} - -// For comparison: -func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { - var a atomic.Value - const want = 42 - a.Store(int(want)) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := a.Load().(int); got != want { - b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) - } - } - }) -} diff --git a/pkg/syncutil/seqcount.go b/pkg/syncutil/seqcount.go deleted file mode 100644 index 11d8dbfaa..000000000 --- a/pkg/syncutil/seqcount.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "fmt" - "reflect" - "runtime" - "sync/atomic" -) - -// SeqCount is a synchronization primitive for optimistic reader/writer -// synchronization in cases where readers can work with stale data and -// therefore do not need to block writers. -// -// Compared to sync/atomic.Value: -// -// - Mutation of SeqCount-protected data does not require memory allocation, -// whereas atomic.Value generally does. This is a significant advantage when -// writes are common. -// -// - Atomic reads of SeqCount-protected data require copying. This is a -// disadvantage when atomic reads are common. -// -// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other -// operations to be made atomic with reads of SeqCount-protected data. -// -// - SeqCount may be less flexible: as of this writing, SeqCount-protected data -// cannot include pointers. -// -// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected -// data require instantiating function templates using go_generics (see -// seqatomic.go). -type SeqCount struct { - // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd - // if a writer critical section is active, and a read from data protected - // by this SeqCount is atomic iff epoch is the same even value before and - // after the read. - epoch uint32 -} - -// SeqCountEpoch tracks writer critical sections in a SeqCount. -type SeqCountEpoch struct { - val uint32 -} - -// We assume that: -// -// - All functions in sync/atomic that perform a memory read are at least a -// read fence: memory reads before calls to such functions cannot be reordered -// after the call, and memory reads after calls to such functions cannot be -// reordered before the call, even if those reads do not use sync/atomic. -// -// - All functions in sync/atomic that perform a memory write are at least a -// write fence: memory writes before calls to such functions cannot be -// reordered after the call, and memory writes after calls to such functions -// cannot be reordered before the call, even if those writes do not use -// sync/atomic. -// -// As of this writing, the Go memory model completely fails to describe -// sync/atomic, but these properties are implied by -// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. - -// BeginRead indicates the beginning of a reader critical section. Reader -// critical sections DO NOT BLOCK writer critical sections, so operations in a -// reader critical section MAY RACE with writer critical sections. Races are -// detected by ReadOk at the end of the reader critical section. Thus, the -// low-level structure of readers is generally: -// -// for { -// epoch := seq.BeginRead() -// // do something idempotent with seq-protected data -// if seq.ReadOk(epoch) { -// break -// } -// } -// -// However, since reader critical sections may race with writer critical -// sections, the Go race detector will (accurately) flag data races in readers -// using this pattern. Most users of SeqCount will need to use the -// SeqAtomicLoad function template in seqatomic.go. -func (s *SeqCount) BeginRead() SeqCountEpoch { - epoch := atomic.LoadUint32(&s.epoch) - for epoch&1 != 0 { - runtime.Gosched() - epoch = atomic.LoadUint32(&s.epoch) - } - return SeqCountEpoch{epoch} -} - -// ReadOk returns true if the reader critical section initiated by a previous -// call to BeginRead() that returned epoch did not race with any writer critical -// sections. -// -// ReadOk may be called any number of times during a reader critical section. -// Reader critical sections do not need to be explicitly terminated; the last -// call to ReadOk is implicitly the end of the reader critical section. -func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { - return atomic.LoadUint32(&s.epoch) == epoch.val -} - -// BeginWrite indicates the beginning of a writer critical section. -// -// SeqCount does not support concurrent writer critical sections; clients with -// concurrent writers must synchronize them using e.g. sync.Mutex. -func (s *SeqCount) BeginWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { - panic("SeqCount.BeginWrite during writer critical section") - } -} - -// EndWrite ends the effect of a preceding BeginWrite. -func (s *SeqCount) EndWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { - panic("SeqCount.EndWrite outside writer critical section") - } -} - -// PointersInType returns a list of pointers reachable from values named -// valName of the given type. -// -// PointersInType is not exhaustive, but it is guaranteed that if typ contains -// at least one pointer, then PointersInTypeOf returns a non-empty list. -func PointersInType(typ reflect.Type, valName string) []string { - switch kind := typ.Kind(); kind { - case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return nil - - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: - return []string{valName} - - case reflect.Array: - return PointersInType(typ.Elem(), valName+"[]") - - case reflect.Struct: - var ptrs []string - for i, n := 0, typ.NumField(); i < n; i++ { - field := typ.Field(i) - ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) - } - return ptrs - - default: - return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} - } -} diff --git a/pkg/syncutil/seqcount_test.go b/pkg/syncutil/seqcount_test.go deleted file mode 100644 index 14d6aedea..000000000 --- a/pkg/syncutil/seqcount_test.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "reflect" - "testing" - "time" -) - -func TestSeqCountWriteUncontended(t *testing.T) { - var seq SeqCount - seq.BeginWrite() - seq.EndWrite() -} - -func TestSeqCountReadUncontended(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadAfterWrite(t *testing.T) { - var seq SeqCount - var data int32 - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadDuringWrite(t *testing.T) { - var seq SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountReadOkAfterWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } -} - -func TestSeqCountReadOkDuringWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } - seq.EndWrite() -} - -func BenchmarkSeqCountWriteUncontended(b *testing.B) { - var seq SeqCount - for i := 0; i < b.N; i++ { - seq.BeginWrite() - seq.EndWrite() - } -} - -func BenchmarkSeqCountReadUncontended(b *testing.B) { - var seq SeqCount - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - b.Fatalf("ReadOk: got false, wanted true") - } - } - }) -} - -func TestPointersInType(t *testing.T) { - for _, test := range []struct { - name string // used for both test and value name - val interface{} - ptrs []string - }{ - { - name: "EmptyStruct", - val: struct{}{}, - }, - { - name: "Int", - val: int(0), - }, - { - name: "MixedStruct", - val: struct { - b bool - I int - ExportedPtr *struct{} - unexportedPtr *struct{} - arr [2]int - ptrArr [2]*int - nestedStruct struct { - nestedNonptr int - nestedPtr *int - } - structArr [1]struct { - nonptr int - ptr *int - } - }{}, - ptrs: []string{ - "MixedStruct.ExportedPtr", - "MixedStruct.unexportedPtr", - "MixedStruct.ptrArr[]", - "MixedStruct.nestedStruct.nestedPtr", - "MixedStruct.structArr[].ptr", - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - typ := reflect.TypeOf(test.val) - ptrs := PointersInType(typ, test.name) - t.Logf("Found pointers: %v", ptrs) - if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { - t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) - } - }) - } -} diff --git a/pkg/syncutil/syncutil.go b/pkg/syncutil/syncutil.go deleted file mode 100644 index 66e750d06..000000000 --- a/pkg/syncutil/syncutil.go +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package syncutil provides synchronization primitives. -package syncutil diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e07ebd153..db06d02c6 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -15,6 +15,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip/buffer", "//pkg/tcpip/iptables", "//pkg/waiter", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 78df5a0b1..3df7d18d3 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index cd6ce930a..a2f44b496 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -20,9 +20,9 @@ import ( "errors" "io" "net" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 897c94821..66cc53ed4 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -16,6 +16,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index fa8a703d9..b7f60178e 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,10 +41,10 @@ package fdbased import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index a4f9cdd69..09165dd4c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -31,6 +32,7 @@ go_test( ], embed = [":sharedmem"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 6b5bc542c..a0d4ad0be 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -21,4 +21,5 @@ go_test( "pipe_test.go", ], embed = [":pipe"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index 59ef69a8b..dc239a0d0 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -18,8 +18,9 @@ import ( "math/rand" "reflect" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSimpleReadWrite(t *testing.T) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 080f9d667..655e537c4 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -23,11 +23,11 @@ package sharedmem import ( - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 89603c48f..5c729a439 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -22,11 +22,11 @@ import ( "math/rand" "os" "strings" - "sync" "syscall" "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index acf1e022c..ed16076fd 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", ], diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6da5238ec..92f2aa13a 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -19,9 +19,9 @@ package fragmentation import ( "fmt" "log" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9e002e396..0a83d81f2 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index e156b01f6..a6ef3bdcc 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 6c5e19e8f..b937cb84b 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -18,9 +18,9 @@ package ports import ( "math" "math/rand" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 826fca4de..6a8654105 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -36,6 +36,7 @@ go_library( "//pkg/ilist", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", @@ -80,6 +81,7 @@ go_test( embed = [":stack"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 267df60d1..403557fd7 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -16,10 +16,10 @@ package stack import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 9946b8fe8..1baa498d0 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,12 +16,12 @@ package stack import ( "fmt" - "sync" "sync/atomic" "testing" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3810c6602..fe557ccbd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,9 +16,9 @@ package stack import ( "strings" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 41bf9fd9b..a47ceba54 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,13 +21,13 @@ package stack import ( "encoding/binary" - "sync" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 67c21be42..f384a91de 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -18,8 +18,8 @@ import ( "fmt" "math/rand" "sort" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 72b5ce179..4a090ac86 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -35,10 +35,10 @@ import ( "reflect" "strconv" "strings" - "sync" "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d8c5b5058..3aa23d529 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index c7ce74cdd..330786f4c 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,8 +15,7 @@ package icmp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 44b58ff6b..4858d150c 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 07ffa8aba..fc5bc69fa 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,8 +25,7 @@ package packet import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 00991ac8e..2f2131ff7 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 85f7eb76b..ee9c4c58b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,8 +26,7 @@ package raw import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 3b353d56c..353bd06f4 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/log", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 5422ae80c..1ea996936 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -19,11 +19,11 @@ import ( "encoding/binary" "hash" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cdd69f360..613ec1775 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,11 +16,11 @@ package tcp import ( "encoding/binary" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 830bc1e3e..cca511fb9 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -19,12 +19,12 @@ import ( "fmt" "math" "strings" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 7aa4c3f0e..4b8d867bc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,9 +16,9 @@ package tcp import ( "fmt" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 4983bca81..7eb613be5 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -15,8 +15,7 @@ package tcp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index bc718064c..9a8f64aa6 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -22,9 +22,9 @@ package tcp import ( "strings" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index e0759225e..bd20a7ee9 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -15,7 +15,7 @@ package tcp import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // segmentQueue is a bounded, thread-safe queue of TCP segments. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8a947dc66..79f2d274b 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -16,11 +16,11 @@ package tcp import ( "math" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 97e4d5825..57ff123e3 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -30,6 +30,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 864dc8733..a4ff29a7d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,7 @@ package udp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 6afdb29b7..07778e4f7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "medium", srcs = ["tmutex_test.go"], embed = [":tmutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go index ce34c7962..05540696a 100644 --- a/pkg/tmutex/tmutex_test.go +++ b/pkg/tmutex/tmutex_test.go @@ -17,10 +17,11 @@ package tmutex import ( "fmt" "runtime" - "sync" "sync/atomic" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicLock(t *testing.T) { diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index 8f6f180e5..d1885ae66 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -24,4 +24,5 @@ go_test( "unet_test.go", ], embed = [":unet"], + deps = ["//pkg/sync"], ) diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index a3cc6f5d3..5c4b9e8e9 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -19,10 +19,11 @@ import ( "os" "path/filepath" "reflect" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func randomFilename() (string, error) { diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b6bbb0ea2..b8fdc3125 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/fd", "//pkg/log", + "//pkg/sync", "//pkg/unet", ], ) diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index df59ffab1..13b2ea314 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -27,10 +27,10 @@ import ( "os" "reflect" "runtime" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 0427bc41f..1c6890e52 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -24,6 +24,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/waiter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 8a65ed164..f708e95fa 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -58,7 +58,7 @@ package waiter import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // EventMask represents io events as used in the poll() syscall. diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 6226b63f8..3e20f8f2f 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -74,6 +74,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/link/fdbased", @@ -114,6 +115,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//runsc/fsgofer", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 352e710d2..9c23b9553 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -17,7 +17,6 @@ package boot import ( "fmt" "os" - "sync" "syscall" "github.com/golang/protobuf/proto" @@ -27,6 +26,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/strace" spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) func initCompatLogs(fd int) error { diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go index d1c0bb9b5..ce62236e5 100644 --- a/runsc/boot/limits.go +++ b/runsc/boot/limits.go @@ -16,12 +16,12 @@ package boot import ( "fmt" - "sync" "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // Mapping from linux resource names to limits.LimitType. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index bc1d0c1bb..fad72f4ab 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -20,7 +20,6 @@ import ( mrand "math/rand" "os" "runtime" - "sync" "sync/atomic" "syscall" gtime "time" @@ -46,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/watchdog" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 147ff7703..bec0dc292 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -19,7 +19,6 @@ import ( "math/rand" "os" "reflect" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/fsgofer" ) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 250845ad7..b94bc4fa0 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go index a4e3071b3..1815c93b9 100644 --- a/runsc/cmd/create.go +++ b/runsc/cmd/create.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 4831210c0..7df7995f0 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -21,7 +21,6 @@ import ( "os" "path/filepath" "strings" - "sync" "syscall" "flag" @@ -30,6 +29,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/fsgofer" diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index de2115dff..5e9bc53ab 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 2bd12120d..6dea179e4 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sentry/control", + "//pkg/sync", "//runsc/boot", "//runsc/cgroup", "//runsc/sandbox", @@ -53,6 +54,7 @@ go_test( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 5ed131a7f..060b63bf3 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -20,7 +20,6 @@ import ( "io" "os" "path/filepath" - "sync" "syscall" "testing" "time" @@ -29,6 +28,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index c10f85992..b54d8f712 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -26,7 +26,6 @@ import ( "reflect" "strconv" "strings" - "sync" "syscall" "testing" "time" @@ -39,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" "gvisor.dev/gvisor/runsc/specutils" diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 4ad09ceab..2da93ec5b 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -22,7 +22,6 @@ import ( "path" "path/filepath" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index d95151ea5..17a251530 100644 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go @@ -20,10 +20,10 @@ import ( "io/ioutil" "os" "path/filepath" - "sync" "github.com/gofrs/flock" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) const stateFileExtension = ".state" diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index afcb41801..a9582d92b 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/syserr", "//runsc/specutils", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index b59e1a70e..93606d051 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -29,7 +29,6 @@ import ( "path/filepath" "runtime" "strconv" - "sync" "syscall" "golang.org/x/sys/unix" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/specutils" ) diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 8001949d5..ddbc37456 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/platform", + "//pkg/sync", "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/urpc", diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index ce1452b87..ec72bdbfd 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -22,7 +22,6 @@ import ( "os" "os/exec" "strconv" - "sync" "syscall" "time" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index c96ca2eb6..3c3027cb5 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index 9632776d2..fb22eae39 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -34,7 +34,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "sync/atomic" "syscall" "time" @@ -42,6 +41,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" ) -- cgit v1.2.3 From a611fdaee3c14abe2222140ae0a8a742ebfd31ab Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 14 Jan 2020 14:14:17 -0800 Subject: Changes TCP packet dispatch to use a pool of goroutines. All inbound segments for connections in ESTABLISHED state are delivered to the endpoint's queue but for every segment delivered we also queue the endpoint for processing to a selected processor. This ensures that when there are a large number of connections in ESTABLISHED state the inbound packets are all handled by a small number of goroutines and significantly reduces the amount of work the goscheduler has to perform. We let connections in other states follow the current path where the endpoint's goroutine directly handles the segments. Updates #231 PiperOrigin-RevId: 289728325 --- benchmarks/tcp/tcp_proxy.go | 6 +- pkg/sleep/sleep_test.go | 31 +++ pkg/tcpip/stack/transport_demuxer.go | 54 ++++- pkg/tcpip/transport/tcp/BUILD | 15 +- pkg/tcpip/transport/tcp/accept.go | 9 +- pkg/tcpip/transport/tcp/connect.go | 310 ++++++++++++++++------------ pkg/tcpip/transport/tcp/dispatcher.go | 218 +++++++++++++++++++ pkg/tcpip/transport/tcp/endpoint.go | 303 ++++++++++++++++++--------- pkg/tcpip/transport/tcp/endpoint_state.go | 30 +-- pkg/tcpip/transport/tcp/protocol.go | 11 + pkg/tcpip/transport/tcp/rcv.go | 21 +- pkg/tcpip/transport/tcp/snd.go | 14 +- pkg/tcpip/transport/tcp/tcp_test.go | 11 +- test/syscalls/linux/socket_inet_loopback.cc | 2 +- test/syscalls/linux/tcp_socket.cc | 14 ++ 15 files changed, 769 insertions(+), 280 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/dispatcher.go (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index be0d7bdd6..dc96add66 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -85,7 +85,7 @@ func (netImpl) printStats() { const ( nicID = 1 // Fixed. - rcvBufSize = 1 << 20 // 1MB. + rcvBufSize = 4 << 20 // 1MB. ) type netstackImpl struct { @@ -130,6 +130,10 @@ func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) } + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, rcvBufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + } + if !*swgso && *gso != 0 { if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index 130806c86..af47e2ba1 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -376,6 +376,37 @@ func TestRace(t *testing.T) { } } +// TestRaceInOrder tests that multiple wakers can continuously send wake requests to +// the sleeper and that the wakers are retrieved in the order asserted. +func TestRaceInOrder(t *testing.T) { + const wakers = 100 + const wakeRequests = 10000 + + w := make([]Waker, wakers) + s := Sleeper{} + + // Associate each waker and start goroutines that will assert them. + for i := range w { + s.AddWaker(&w[i], i) + } + go func() { + n := 0 + for n < wakeRequests { + wk := w[n%len(w)] + wk.Assert() + n++ + } + }() + + // Wait for all wake up notifications from all wakers. + for i := 0; i < wakeRequests; i++ { + v, _ := s.Fetch(true) + if got, want := v, i%wakers; got != want { + t.Fatalf("got %d want %d", got, want) + } + } +} + // BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up // from 4 wakers when at least one is already asserted. func BenchmarkSleeperMultiSelect(b *testing.B) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f384a91de..d686e6eb8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + transEP := selectEndpoint(id, mpep, epsByNic.seed) + if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { + queuedProtocol.QueuePacket(r, transEP, id, pkt) + epsByNic.mu.RUnlock() + return + } + + transEP.HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } @@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { epsByNic.mu.Lock() defer epsByNic.mu.Unlock() @@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort } // This is a new binding. - multiPortEp := &multiPortEndpoint{} + multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} multiPortEp.endpointsMap = make(map[TransportEndpoint]int) multiPortEp.reuse = reusePort epsByNic.endpoints[bindToDevice] = multiPortEp @@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T // newTransportDemuxer. type transportDemuxer struct { // protocol is immutable. - protocol map[protocolIDs]*transportEndpoints + protocol map[protocolIDs]*transportEndpoints + queuedProtocols map[protocolIDs]queuedTransportProtocol +} + +// queuedTransportProtocol if supported by a protocol implementation will cause +// the dispatcher to delivery packets to the QueuePacket method instead of +// calling HandlePacket directly on the endpoint. +type queuedTransportProtocol interface { + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { - d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + d := &transportDemuxer{ + protocol: make(map[protocolIDs]*transportEndpoints), + queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), + } // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + protoIDs := protocolIDs{netProto, proto} + d.protocol[protoIDs] = &transportEndpoints{ endpoints: make(map[TransportEndpointID]*endpointsByNic), } + qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) + if isQueued { + d.queuedProtocols[protoIDs] = qTransProto + } } } @@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + demux *transportDemuxer + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. @@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() + queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] for i, endpoint := range ep.endpointsArr { // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + break + } endpoint.HandlePacket(r, id, pkt) break } + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + continue + } endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. @@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if epsByNic, ok := eps.endpoints[id]; ok { // There was already a binding. - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // This is a new binding. @@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol } eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 353bd06f4..0e3ab05ad 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -16,6 +16,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tcp_endpoint_list", + out = "tcp_endpoint_list.go", + package = "tcp", + prefix = "endpoint", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*endpoint", + "Linker": "*endpoint", + }, +) + go_library( name = "tcp", srcs = [ @@ -23,6 +35,7 @@ go_library( "connect.go", "cubic.go", "cubic_state.go", + "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", @@ -38,6 +51,7 @@ go_library( "segment_state.go", "snd.go", "snd_state.go", + "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", ], @@ -45,7 +59,6 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ - "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1ea996936..1a2e3efa9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. if l.listenEP != nil { l.listenEP.mu.Lock() - if l.listenEP.state != StateListen { + if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() return nil, tcpip.ErrConnectionAborted } @@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() { // instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.state + state := e.EndpointState() e.pendingAccepted.Add(1) defer e.pendingAccepted.Done() acceptedChan := e.acceptedChan e.mu.Unlock() + if state == StateListen { acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) @@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // We do not use transitionToStateEstablishedLocked here as there is // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() - n.state = StateEstablished n.isConnectNotified = true + n.setEndpointState(StateEstablished) // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = StateClose + e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 613ec1775..f3896715b 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.ep.mu.Lock() - h.ep.state = StateSynRecv + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() } @@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // SYN-RCVD state. h.state = handshakeSynRcvd h.ep.mu.Lock() - h.ep.state = StateSynRecv ttl := h.ep.ttl + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), // We only send SACKPermitted if the other side indicated it // permits SACK. This is not explicitly defined in the RFC but @@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { WS: h.rcvWndScale, TS: h.ep.sendTSOk, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } @@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error { WS: h.rcvWndScale, TS: true, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } @@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } if e.sackPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) @@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error { } func (e *endpoint) handleClose() *tcpip.Error { + if !e.EndpointState().connected() { + return nil + } // Drain the send queue. e.handleWrite() @@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } - e.state = StateError + e.setEndpointState(StateError) e.HardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the @@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } // completeWorkerLocked is called by the worker goroutine when it's about to -// exit. It marks the worker as completed and performs cleanup work if requested -// by Close(). +// exit. func (e *endpoint) completeWorkerLocked() { + // Worker is terminating(either due to moving to + // CLOSED or ERROR state, ensure we release all + // registrations port reservations even if the socket + // itself is not yet closed by the application. e.workerRunning = false if e.workerCleanup { e.cleanupLocked() @@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { e.rcvAutoParams.prevCopied = int(h.rcvWnd) e.rcvListMu.Unlock() } - h.ep.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished + e.setEndpointState(StateEstablished) } // transitionToStateCloseLocked ensures that the endpoint is @@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // delivered to this endpoint from the demuxer when the endpoint // is transitioned to StateClose. func (e *endpoint) transitionToStateCloseLocked() { - if e.state == StateClose { + if e.EndpointState() == StateClose { return } + // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() - e.state = StateClose + e.setEndpointState(StateClose) e.stack.Stats().TCP.EstablishedClosed.Increment() } @@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { s.decRef() return } - ep.(*endpoint).enqueueSegment(s) + if ep.(*endpoint).enqueueSegment(s) { + ep.(*endpoint).newSegmentWaker.Assert() + } } func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { @@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - s.decRef() e.mu.Lock() - switch e.state { + switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set // to indicate EPIPE. @@ -981,103 +984,57 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: e.mu.Unlock() + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + + // Notify protocol goroutine. This is required when + // handleSegment is invoked from the processor goroutine + // rather than the worker goroutine. + e.notifyProtocolGoroutine(notifyResetByPeer) return false, tcpip.ErrConnectionReset } } return true, nil } -// handleSegments pulls segments from the queue and processes them. It returns -// no error if the protocol loop should continue, an error otherwise. -func (e *endpoint) handleSegments() *tcpip.Error { +// handleSegments processes all inbound segments. +func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + if e.EndpointState() == StateClose || e.EndpointState() == StateError { + return nil + } s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false break } - // Invoke the tcp probe if installed. - if e.probe != nil { - e.probe(e.completeState()) + cont, err := e.handleSegment(s) + if err != nil { + s.decRef() + e.mu.Lock() + e.setEndpointState(StateError) + e.HardError = err + e.mu.Unlock() + return err } - - if s.flagIsSet(header.TCPFlagRst) { - if ok, err := e.handleReset(s); !ok { - return err - } - } else if s.flagIsSet(header.TCPFlagSyn) { - // See: https://tools.ietf.org/html/rfc5961#section-4.1 - // 1) If the SYN bit is set, irrespective of the sequence number, TCP - // MUST send an ACK (also referred to as challenge ACK) to the remote - // peer: - // - // - // - // After sending the acknowledgment, TCP MUST drop the unacceptable - // segment and stop processing further. - // - // By sending an ACK, the remote peer is challenged to confirm the loss - // of the previous connection and the request to start a new connection. - // A legitimate peer, after restart, would not have a TCB in the - // synchronized state. Thus, when the ACK arrives, the peer should send - // a RST segment back with the sequence number derived from the ACK - // field that caused the RST. - - // This RST will confirm that the remote peer has indeed closed the - // previous connection. Upon receipt of a valid RST, the local TCP - // endpoint MUST terminate its connection. The local TCP endpoint - // should then rely on SYN retransmission from the remote end to - // re-establish the connection. - - e.snd.sendAck() - } else if s.flagIsSet(header.TCPFlagAck) { - // Patch the window size in the segment according to the - // send window scale. - s.window <<= e.snd.sndWndScale - - // RFC 793, page 41 states that "once in the ESTABLISHED - // state all segments must carry current acknowledgment - // information." - drop, err := e.rcv.handleRcvdSegment(s) - if err != nil { - s.decRef() - return err - } - if drop { - s.decRef() - continue - } - - // Now check if the received segment has caused us to transition - // to a CLOSED state, if yes then terminate processing and do - // not invoke the sender. - e.mu.RLock() - state := e.state - e.mu.RUnlock() - if state == StateClose { - // When we get into StateClose while processing from the queue, - // return immediately and let the protocolMainloop handle it. - // - // We can reach StateClose only while processing a previous segment - // or a notification from the protocolMainLoop (caller goroutine). - // This means that with this return, the segment dequeue below can - // never occur on a closed endpoint. - s.decRef() - return nil - } - e.snd.handleRcvdSegment(s) + if !cont { + s.decRef() + return nil } - s.decRef() } - // If the queue is not empty, make sure we'll wake up in the next - // iteration. - if checkRequeue && !e.segmentQueue.empty() { + // When fastPath is true we don't want to wake up the worker + // goroutine. If the endpoint has more segments to process the + // dispatcher will call handleSegments again anyway. + if !fastPath && checkRequeue && !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } @@ -1086,11 +1043,88 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } - e.resetKeepaliveTimer(true) + e.resetKeepaliveTimer(true /* receivedData */) return nil } +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(header.TCPFlagRst) { + if ok, err := e.handleReset(s); !ok { + return false, err + } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + return false, err + } + if drop { + return true, nil + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return false, nil + } + + e.snd.handleRcvdSegment(s) + } + + return true, nil +} + // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. @@ -1160,7 +1194,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1182,6 +1216,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() + e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1193,7 +1228,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) e.mu.Lock() - h.ep.state = StateSynSent + h.ep.setEndpointState(StateSynSent) e.mu.Unlock() if err := h.execute(); err != nil { @@ -1202,12 +1237,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.state = StateError + e.setEndpointState(StateError) e.HardError = err // Lock released below. epilogue() - return err } } @@ -1215,7 +1249,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - // Tell waiters that the endpoint is connected and writable. e.mu.Lock() drained := e.drainDone != nil e.mu.Unlock() @@ -1224,8 +1257,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { <-e.undrain } - e.waiterQueue.Notify(waiter.EventOut) - // Set up the functions that will be called when the main protocol loop // wakes up. funcs := []struct { @@ -1240,18 +1271,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.sndCloseWaker, f: e.handleClose, }, - { - w: &e.newSegmentWaker, - f: e.handleSegments, - }, { w: &closeWaker, f: func() *tcpip.Error { // This means the socket is being closed due - // to the TCP_FIN_WAIT2 timeout was hit. Just + // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() e.transitionToStateCloseLocked() + e.workerCleanup = true e.mu.Unlock() return nil }, @@ -1266,6 +1294,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil }, }, + { + w: &e.newSegmentWaker, + f: func() *tcpip.Error { + return e.handleSegments(false /* fastPath */) + }, + }, { w: &e.keepalive.waker, f: e.keepaliveTimerExpired, @@ -1293,14 +1327,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } if n¬ifyReset != 0 { - e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.mu.Unlock() + return tcpip.ErrConnectionAborted + } + + if n¬ifyResetByPeer != 0 { + return tcpip.ErrConnectionReset } if n¬ifyClose != 0 && closeTimer == nil { e.mu.Lock() - if e.state == StateFinWait2 && e.closed { + if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { @@ -1320,11 +1356,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(); err != nil { + if err := e.handleSegments(false /* fastPath */); err != nil { return err } } - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1349,14 +1385,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.AddWaker(funcs[i].w, i) } + // Notify the caller that the waker initialization is complete and the + // endpoint is ready. + if wakerInitDone != nil { + close(wakerInitDone) + } + + // Tell waiters that the endpoint is connected and writable. + e.waiterQueue.Notify(waiter.EventOut) + // The following assertions and notifications are needed for restored // endpoints. Fresh newly created endpoints have empty states and should // not invoke any. - e.segmentQueue.mu.Lock() - if !e.segmentQueue.list.Empty() { + if !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } - e.segmentQueue.mu.Unlock() e.rcvListMu.Lock() if !e.rcvList.Empty() { @@ -1372,27 +1415,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() + // We need to double check here because the notification + // maybe stale by the time we got around to processing it. + // NOTE: since we now hold the workMu the processors cannot + // change the state of the endpoint so it' safe to proceed + // after this check. + if e.EndpointState() == StateTimeWait || e.EndpointState() == StateClose || e.EndpointState() == StateError { + e.mu.Lock() + break + } if err := funcs[v].f(); err != nil { e.mu.Lock() - // Ensure we release all endpoint registration and route - // references as the connection is now in an error - // state. e.workerCleanup = true e.resetConnectionLocked(err) // Lock released below. epilogue() - return nil } e.mu.Lock() } - state := e.state + state := e.EndpointState() e.mu.Unlock() var reuseTW func() if state == StateTimeWait { @@ -1405,13 +1453,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.mu.Lock() + e.workerCleanup = true + e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() - if e.state != StateError { - e.stack.Stats().TCP.CurrentEstablished.Decrement() + if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1468,7 +1518,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { - tcpEP.enqueueSegment(s) + if !tcpEP.enqueueSegment(s) { + s.decRef() + return + } + tcpEP.newSegmentWaker.Assert() } // We explicitly do not decRef // the segment as it's still diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go new file mode 100644 index 000000000..a72f0c379 --- /dev/null +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -0,0 +1,218 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// epQueue is a queue of endpoints. +type epQueue struct { + mu sync.Mutex + list endpointList +} + +// enqueue adds e to the queue if the endpoint is not already on the queue. +func (q *epQueue) enqueue(e *endpoint) { + q.mu.Lock() + if e.pendingProcessing { + q.mu.Unlock() + return + } + q.list.PushBack(e) + e.pendingProcessing = true + q.mu.Unlock() +} + +// dequeue removes and returns the first element from the queue if available, +// returns nil otherwise. +func (q *epQueue) dequeue() *endpoint { + q.mu.Lock() + if e := q.list.Front(); e != nil { + q.list.Remove(e) + e.pendingProcessing = false + q.mu.Unlock() + return e + } + q.mu.Unlock() + return nil +} + +// empty returns true if the queue is empty, false otherwise. +func (q *epQueue) empty() bool { + q.mu.Lock() + v := q.list.Empty() + q.mu.Unlock() + return v +} + +// processor is responsible for processing packets queued to a tcp endpoint. +type processor struct { + epQ epQueue + newEndpointWaker sleep.Waker + id int +} + +func newProcessor(id int) *processor { + p := &processor{ + id: id, + } + go p.handleSegments() + return p +} + +func (p *processor) queueEndpoint(ep *endpoint) { + // Queue an endpoint for processing by the processor goroutine. + p.epQ.enqueue(ep) + p.newEndpointWaker.Assert() +} + +func (p *processor) handleSegments() { + const newEndpointWaker = 1 + s := sleep.Sleeper{} + s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + defer s.Done() + for { + s.Fetch(true) + for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + if ep.segmentQueue.empty() { + continue + } + + // If socket has transitioned out of connected state + // then just let the worker handle the packet. + // + // NOTE: We read this outside of e.mu lock which means + // that by the time we get to handleSegments the + // endpoint may not be in ESTABLISHED. But this should + // be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a + // CLOSED/ERROR state then handleSegments is a noop. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + continue + } + + if !ep.workMu.TryLock() { + ep.newSegmentWaker.Assert() + continue + } + // If the endpoint is in a connected state then we do + // direct delivery to ensure low latency and avoid + // scheduler interactions. + if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { + ep.notifyProtocolGoroutine(notifyTickleWorker) + ep.workMu.Unlock() + continue + } + + if !ep.segmentQueue.empty() { + p.epQ.enqueue(ep) + } + + ep.workMu.Unlock() + } + } +} + +// dispatcher manages a pool of TCP endpoint processors which are responsible +// for the processing of inbound segments. This fixed pool of processor +// goroutines do full tcp processing. The processor is selected based on the +// hash of the endpoint id to ensure that delivery for the same endpoint happens +// in-order. +type dispatcher struct { + processors []*processor + seed uint32 +} + +func newDispatcher(nProcessors int) *dispatcher { + processors := []*processor{} + for i := 0; i < nProcessors; i++ { + processors = append(processors, newProcessor(i)) + } + return &dispatcher{ + processors: processors, + seed: generateRandUint32(), + } +} + +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + ep := stackEP.(*endpoint) + s := newSegment(r, id, pkt) + if !s.parse() { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + s.decRef() + return + } + + if !s.csumValid { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.ChecksumErrors.Increment() + ep.stats.ReceiveErrors.ChecksumErrors.Increment() + s.decRef() + return + } + + ep.stack.Stats().TCP.ValidSegmentsReceived.Increment() + ep.stats.SegmentsReceived.Increment() + if (s.flags & header.TCPFlagRst) != 0 { + ep.stack.Stats().TCP.ResetsReceived.Increment() + } + + if !ep.enqueueSegment(s) { + s.decRef() + return + } + + // For sockets not in established state let the worker goroutine + // handle the packets. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + return + } + + d.selectProcessor(id).queueEndpoint(ep) +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { + payload := []byte{ + byte(id.LocalPort), + byte(id.LocalPort >> 8), + byte(id.RemotePort), + byte(id.RemotePort >> 8)} + + h := jenkins.Sum32(d.seed) + h.Write(payload) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + + return d.processors[h.Sum32()%uint32(len(d.processors))] +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cc8b533c8..1799c6e10 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -120,6 +120,7 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyResetByPeer notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -127,6 +128,7 @@ const ( // ensures the loop terminates if the final state of the endpoint is // say TIME_WAIT. notifyTickleWorker + notifyError ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {} type endpoint struct { EndpointInfo + // endpointEntry is used to queue endpoints for processing to the + // a given tcp processor goroutine. + // + // Precondition: epQueue.mu must be held to read/write this field.. + endpointEntry `state:"nosave"` + + // pendingProcessing is true if this endpoint is queued for processing + // to a TCP processor. + // + // Precondition: epQueue.mu must be held to read/write this field.. + pendingProcessing bool `state:"nosave"` + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -324,6 +338,7 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` + // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -359,7 +374,7 @@ type endpoint struct { workerRunning bool // workerCleanup specifies if the worker goroutine must perform cleanup - // before exitting. This can only be set to true when workerRunning is + // before exiting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. workerCleanup bool @@ -371,6 +386,8 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. + // + // recentTS must be read/written atomically. recentTS uint32 // tsOffset is a randomized offset added to the value of the @@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() { e.workMu.Unlock() } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + switch state { + case StateEstablished: + e.stack.Stats().TCP.CurrentEstablished.Increment() + case StateError: + fallthrough + case StateClose: + if oldstate == StateCloseWait || oldstate == StateEstablished { + e.stack.Stats().TCP.EstablishedResets.Increment() + } + fallthrough + default: + if oldstate == StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } + } + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp atomically sets the recentTS field to the +// provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + atomic.StoreUint32(&e.recentTS, recentTS) +} + +// recentTimestamp atomically reads and returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return atomic.LoadUint32(&e.recentTS) +} + // keepalive is a synchronization wrapper used to appease stateify. See the // comment in endpoint, where it is used. // @@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.mu.RLock() defer e.mu.RUnlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } } - if e.state.connected() { + if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -733,14 +791,20 @@ func (e *endpoint) Close() { // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdown() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdown() { e.mu.Lock() // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == StateListen && e.isPortReserved { + if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false @@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { defer close(done) for n := range e.acceptedChan { n.notifyProtocolGoroutine(notifyReset) + // close all connections that have completed but + // not accepted by the application. n.Close() } }() @@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { e.closePendingAcceptableConnectionsLocked() } + e.workerCleanup = false if e.isRegistered { @@ -920,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError e.mu.RUnlock() @@ -944,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -980,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. - if !e.state.connected() { - switch e.state { + if !e.EndpointState().connected() { + switch e.EndpointState() { case StateError: return 0, e.HardError default: @@ -1039,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, perr } - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + if opts.Atomic { + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { + if e.workMu.TryLock() { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] } - } - - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - if e.workMu.TryLock() { // Do the work inline. e.handleWrite() e.workMu.Unlock() } else { + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + + } // Let the protocol goroutine do the work. e.sndWaker.Assert() } @@ -1091,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; !s.connected() && s != StateClose { + if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { return 0, tcpip.ControlMessages{}, e.HardError } @@ -1103,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } @@ -1187,7 +1299,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1402,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Acquire the work mutex as we may need to // reinitialize the congestion control state. e.mu.Lock() - state := e.state + state := e.EndpointState() e.cc = v e.mu.Unlock() switch state { case StateEstablished: e.workMu.Lock() e.mu.Lock() - if e.state == state { + if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } e.mu.Unlock() @@ -1472,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == StateListen { + if e.EndpointState() == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1731,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return err } - if e.state.connected() { + if e.EndpointState().connected() { // The endpoint is already connected. If caller hasn't been // notified yet, return success. if !e.isConnectNotified { @@ -1743,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } nicID := addr.NIC - switch e.state { + switch e.EndpointState() { case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. @@ -1850,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.isRegistered = true - e.state = StateConnecting + e.setEndpointState(StateConnecting) e.route = r.Clone() e.boundNICID = nicID e.effectiveNetProtos = netProtos @@ -1871,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = StateEstablished - e.stack.Stats().TCP.CurrentEstablished.Increment() + e.setEndpointState(StateEstablished) } if run { e.workerRunning = true e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } return tcpip.ErrConnectStarted @@ -1896,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags finQueued := false switch { - case e.state.connected(): + case e.EndpointState().connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1908,8 +2019,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) + // Move the socket to error state immediately. + // This is done redundantly because in case of + // save/restore on a Shutdown/Close() the socket + // state needs to indicate the error otherwise + // save file will show the socket in established + // state even though snd/rcv are closed. e.mu.Unlock() + // Try to send an active reset immediately if the + // work mutex is available. + if e.workMu.TryLock() { + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + e.workMu.Unlock() + } else { + e.notifyProtocolGoroutine(notifyReset) + } return nil } } @@ -1931,11 +2057,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { finQueued = true // Mark endpoint as closed. e.sndClosed = true - e.sndBufMu.Unlock() } - case e.state == StateListen: + case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1976,7 +2101,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == StateListen && !e.workerCleanup { + if e.EndpointState() == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1994,7 +2119,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { // The listen is called on an unbound socket, the socket is // automatically bound to a random free port with the local // address set to INADDR_ANY. @@ -2004,7 +2129,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Endpoint must be bound before it can transition to listen mode. - if e.state != StateBound { + if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } @@ -2015,24 +2140,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } e.isRegistered = true - e.state = StateListen + e.setEndpointState(StateListen) + if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } e.workerRunning = true - go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) - return nil } // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.mu.Lock() e.waiterQueue = waiterQueue e.workerRunning = true - go e.protocolMainLoop(false) // S/R-SAFE: drained on save. + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone } // Accept returns a new endpoint if a peer has established a connection @@ -2042,7 +2170,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != StateListen { + if e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -2069,7 +2197,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrAlreadyBound } @@ -2131,7 +2259,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) return nil } @@ -2153,7 +2281,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if !e.state.connected() { + if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -2164,45 +2292,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { - s := newSegment(r, id, pkt) - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - s.decRef() - return - } - - if !s.csumValid { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() - e.stats.SegmentsReceived.Increment() - if (s.flags & header.TCPFlagRst) != 0 { - e.stack.Stats().TCP.ResetsReceived.Increment() - } - - e.enqueueSegment(s) + // TCP HandlePacket is not required anymore as inbound packets first + // land at the Dispatcher which then can either delivery using the + // worker go routine or directly do the invoke the tcp processing inline + // based on the state of the endpoint. } -func (e *endpoint) enqueueSegment(s *segment) { +func (e *endpoint) enqueueSegment(s *segment) bool { // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - e.newSegmentWaker.Assert() - } else { + if !e.segmentQueue.enqueue(s) { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.SegmentQueueDropped.Increment() - s.decRef() + return false } + return true } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. @@ -2319,8 +2424,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { - e.recentTS = tsVal + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) } } @@ -2330,7 +2435,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { e.sendTSOk = true - e.recentTS = synOpts.TSVal + e.setRecentTimestamp(synOpts.TSVal) } } @@ -2419,7 +2524,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Endpoint TCP Option state. s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTS + s.RecentTS = e.recentTimestamp() s.TSOffset = e.tsOffset s.SACKPermitted = e.sackPermitted s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) @@ -2526,9 +2631,7 @@ func (e *endpoint) initGSO() { // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4b8d867bc..4a46f0ec5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,6 +16,7 @@ package tcp import ( "fmt" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sync" @@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound: // TODO(b/138137272): this enumeration duplicates // EndpointState.connected. remove it. @@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() { fallthrough case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } - fallthrough case StateError, StateClose: - for (e.state == StateError || e.state == StateClose) && e.workerRunning { + for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() } if e.workerRunning { - panic("endpoint still has worker running in closed or error state") + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) } default: - panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) } if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { + if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { // saveState is invoked by stateify. func (e *endpoint) saveState() EndpointState { - return e.state + return e.EndpointState() } // Endpoint loading must be done in the following ordering by their state, to @@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. - if state.connected() { + // For restore purposes we treat TimeWait like a connected endpoint. + if state.connected() || state == StateTimeWait { connectedLoading.Add(1) } switch state { @@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) { case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } - e.state = state + // Directly update the state here rather than using e.setEndpointState + // as the endpoint is still being loaded and the stack reference to increment + // metrics is not yet initialized. + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - // Freeze segment queue before registering to prevent any segments - // from being delivered while it is being restored. e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. @@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) { e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() state := e.origEndpointState - switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = StateClose + e.setEndpointState(StateClose) tcpip.AsyncLoading.Done() }() } @@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } + } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 9a8f64aa6..958c06fa7 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "runtime" "strings" "time" @@ -104,6 +105,7 @@ type protocol struct { moderateReceiveBuffer bool tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration + dispatcher *dispatcher } // Number returns the tcp protocol number. @@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } +// QueuePacket queues packets targeted at an endpoint after hashing the packet +// to a specific processing queue. Each queue is serviced by its own processor +// goroutine which is responsible for dequeuing and doing full TCP dispatch of +// the packet. +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + p.dispatcher.queuePacket(r, ep, id, pkt) +} + // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. // @@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 05c8488f8..958f03ac1 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -169,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateEstablished: - r.ep.state = StateCloseWait + r.ep.setEndpointState(StateCloseWait) case StateFinWait1: if s.flagIsSet(header.TCPFlagAck) { // FIN-ACK, transition to TIME-WAIT. - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } else { // Simultaneous close, expecting a final ACK. - r.ep.state = StateClosing + r.ep.setEndpointState(StateClosing) } case StateFinWait2: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } r.ep.mu.Unlock() @@ -205,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateFinWait1: - r.ep.state = StateFinWait2 + r.ep.setEndpointState(StateFinWait2) // Notify protocol goroutine that we have received an // ACK to our FIN so that it can start the FIN_WAIT2 // timer to abort connection if the other side does // not close within 2MSL. r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) case StateLastAck: r.ep.transitionToStateCloseLocked() } @@ -267,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo switch state { case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { - s.decRef() // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -284,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - s.decRef() return true, tcpip.ErrConnectionAborted } if state == StateFinWait1 { @@ -314,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - s.decRef() return true, tcpip.ErrConnectionAborted } } @@ -336,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { r.ep.mu.RLock() - state := r.ep.state + state := r.ep.EndpointState() closed := r.ep.closed r.ep.mu.RUnlock() diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index fdff7ed81..b74b61e7d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -705,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) - // Transition to FIN-WAIT1 state since we're initiating an active close. - s.ep.mu.Lock() - switch s.ep.state { + // Update the state to reflect that we have now + // queued a FIN. + switch s.ep.EndpointState() { case StateCloseWait: - // We've already received a FIN and are now sending our own. The - // sender is now awaiting a final ACK for this FIN. - s.ep.state = StateLastAck + s.ep.setEndpointState(StateLastAck) default: - s.ep.state = StateFinWait1 + s.ep.setEndpointState(StateFinWait1) } - s.ep.mu.Unlock() + } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6edfa8dce..a9dfbe857 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SeqNum(uint32(c.IRS+1)), checker.AckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence number + // of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(0), checker.TCPFlags(header.TCPFlagRst), @@ -1500,6 +1502,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence + // number of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. @@ -5441,6 +5446,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { rawEP.SendPacketWithTS(b[start:start+mss], tsVal) packetsSent++ } + // Resume the worker so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -5508,7 +5514,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { stk := c.Stack() // Set lower limits for auto-tuning tests. This is required because the // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. + // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { @@ -5563,6 +5569,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalSent += mss packetsSent++ } + // Resume it so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 5d114d460..2f9821555 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -533,7 +533,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // Sleep for a little over the linger timeout to reduce flakiness in // save/restore tests. - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); ds.reset(); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 6b99c021d..33a5ac66c 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -814,6 +814,20 @@ TEST_P(TcpSocketTest, FullBuffer) { t_ = -1; } +TEST_P(TcpSocketTest, PollAfterShutdown) { + ScopedThread client_thread([this]() { + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); + }); + + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); +} + TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { // Initialize address to the loopback one. sockaddr_storage addr = -- cgit v1.2.3 From 7e6fbc6afe797752efe066a8aa86f9eca973f3a4 Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Tue, 21 Jan 2020 14:47:04 -0800 Subject: Add a new TCP stat for current open connections. Such a stat accounts for all connections that are currently established and not yet transitioned to close state. Also fix bug in double increment of CurrentEstablished stat. Fixes #1579 PiperOrigin-RevId: 290827365 --- pkg/sentry/socket/netstack/netstack.go | 3 +- pkg/tcpip/tcpip.go | 6 ++- pkg/tcpip/transport/tcp/accept.go | 1 - pkg/tcpip/transport/tcp/connect.go | 1 + pkg/tcpip/transport/tcp/endpoint.go | 1 + pkg/tcpip/transport/tcp/tcp_test.go | 83 ++++++++++++++++++++++++++++++++++ 6 files changed, 92 insertions(+), 3 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 2662fbc0f..318acbeff 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -150,7 +150,8 @@ var Metrics = tcpip.Stats{ TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), - CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."), + CurrentConnected: mustCreateMetric("/netstack/tcp/current_open", "Number of connections that are in connected state."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3fc823a36..59c9b3fb0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -938,9 +938,13 @@ type TCPStats struct { PassiveConnectionOpenings *StatCounter // CurrentEstablished is the number of TCP connections for which the - // current state is either ESTABLISHED or CLOSE-WAIT. + // current state is ESTABLISHED. CurrentEstablished *StatCounter + // CurrentConnected is the number of TCP connections that + // are in connected state. + CurrentConnected *StatCounter + // EstablishedResets is the number of times TCP connections have made // a direct transition to the CLOSED state from either the // ESTABLISHED state or the CLOSE-WAIT state. diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1a2e3efa9..d469758eb 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -562,7 +562,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // Switch state to connected. // We do not use transitionToStateEstablishedLocked here as there is // no handshake state available when doing a SYN cookie based accept. - n.stack.Stats().TCP.CurrentEstablished.Increment() n.isConnectNotified = true n.setEndpointState(StateEstablished) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a2f384384..4e3c5419c 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -934,6 +934,7 @@ func (e *endpoint) transitionToStateCloseLocked() { // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() e.setEndpointState(StateClose) + e.stack.Stats().TCP.CurrentConnected.Decrement() e.stack.Stats().TCP.EstablishedClosed.Increment() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4797f11d1..13718ff55 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -594,6 +594,7 @@ func (e *endpoint) setEndpointState(state EndpointState) { switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() + e.stack.Stats().TCP.CurrentConnected.Increment() case StateError: fallthrough case StateClose: diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a9dfbe857..df2fb1071 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -470,6 +470,89 @@ func TestConnectResetAfterClose(t *testing.T) { } } +// TestCurrentConnectedIncrement tests increment of the current +// established and connected counters. +func TestCurrentConnectedIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got) + } + gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() + if gotConnected != 1 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected) + } + + ep.Close() + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected) + } + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Wait for the TIME-WAIT state to transition to CLOSED. + time.Sleep(1 * time.Second) + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got) + } +} + // TestClosingWithEnqueuedSegments tests handling of still enqueued segments // when the endpoint transitions to StateClose. The in-flight segments would be // re-enqueued to a any listening endpoint. -- cgit v1.2.3 From 51b783505b1ec164b02b48a0fd234509fba01a73 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 29 Jan 2020 15:41:51 -0800 Subject: Add support for TCP_DEFER_ACCEPT. PiperOrigin-RevId: 292233574 --- pkg/sentry/socket/netstack/netstack.go | 22 ++++ pkg/tcpip/tcpip.go | 6 ++ pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 25 ++--- pkg/tcpip/transport/tcp/connect.go | 53 +++++++++- pkg/tcpip/transport/tcp/endpoint.go | 26 ++++- pkg/tcpip/transport/tcp/forwarder.go | 4 +- pkg/tcpip/transport/tcp/tcp_test.go | 126 ++++++++++++++++++++++ test/syscalls/linux/socket_inet_loopback.cc | 158 ++++++++++++++++++++++++++++ test/syscalls/linux/tcp_socket.cc | 53 ++++++++++ 10 files changed, 451 insertions(+), 23 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8619cc506..049d04bf2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1260,6 +1260,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_DEFER_ACCEPT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPDeferAcceptOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1713,6 +1725,16 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_DEFER_ACCEPT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 59c9b3fb0..0fa141d58 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -626,6 +626,12 @@ type TCPLingerTimeoutOption time.Duration // before being marked closed. type TCPTimeWaitTimeoutOption time.Duration +// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a +// accept to return a completed connection only when there is data to be +// read. This usually means the listening socket will drop the final ACK +// for a handshake till the specified timeout until a segment with data arrives. +type TCPDeferAcceptOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4acd9fb9a..7b4a87a2d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -57,6 +57,7 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d469758eb..6101f2945 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { netProto = s.route.NetProto } - n := newEndpoint(l.stack, netProto, nil) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6only n.ID = s.id n.boundNICID = s.route.NICID() @@ -273,16 +273,17 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // createEndpoint creates a new endpoint in connected state and then performs // the TCP 3-way handshake. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) if err != nil { return nil, err } // listenEP is nil when listenContext is used by tcp.Forwarder. + deferAccept := time.Duration(0) if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { @@ -290,13 +291,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } // Perform the 3-way handshake. - h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - - h.resetToSynRcvd(isn, irs, opts) + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -377,16 +377,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer e.decSynRcvdCount() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(n) @@ -546,7 +544,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -576,8 +574,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // space available in the backlog. // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4e3c5419c..9ff7ac261 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -86,6 +86,19 @@ type handshake struct { // rcvWndScale is the receive window scale, as defined in RFC 1323. rcvWndScale int + + // startTime is the time at which the first SYN/SYN-ACK was sent. + startTime time.Time + + // deferAccept if non-zero will drop the final ACK for a passive + // handshake till an ACK segment with data is received or the timeout is + // hit. + deferAccept time.Duration + + // acked is true if the the final ACK for a 3-way handshake has + // been received. This is required to stop retransmitting the + // original SYN-ACK when deferAccept is enabled. + acked bool } func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { @@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { return h } +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) + h.resetToSynRcvd(isn, irs, opts, deferAccept) + return h +} + // FindWndScale determines the window scale to use for the given maximum window // size. func FindWndScale(wnd seqnum.Size) int { @@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.ackNum = irs + 1 h.mss = opts.MSS h.sndWndScale = opts.WS + h.deferAccept = deferAccept h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() @@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { + // If deferAccept is not zero and this is a bare ACK and the + // timeout is not hit then drop the ACK. + if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + h.acked = true + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. @@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver + // to process it again once main loop is started. + if s.data.Size() > 0 { + s.incRef() + h.ep.enqueueSegment(s) + } h.ep.mu.Unlock() - return nil } @@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.startTime = time.Now() // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) @@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error { switch index, _ := s.Fetch(true); index { case wakerForResend: timeOut *= 2 - if timeOut > 60*time.Second { + if timeOut > MaxRTO { return tcpip.ErrTimeout } rt.Reset(timeOut) - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + // Resend the SYN/SYN-ACK only if the following conditions hold. + // - It's an active handshake (deferAccept does not apply) + // - It's a passive handshake and we have not yet got the final-ACK. + // - It's a passive handshake and we got an ACK but deferAccept is + // enabled and we are now past the deferAccept duration. + // The last is required to provide a way for the peer to complete + // the connection with another ACK or data (as ACKs are never + // retransmitted on their own). + if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + } case wakerForNotification: n := h.ep.fetchNotifications() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 13718ff55..8d52414b7 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -498,6 +498,13 @@ type endpoint struct { // without any data being acked. userTimeout time.Duration + // deferAccept if non-zero specifies a user specified time during + // which the final ACK of a handshake will be dropped provided the + // ACK is a bare ACK and carries no data. If the timeout is crossed then + // the bare ACK is accepted and the connection is delivered to the + // listener. + deferAccept time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -1574,6 +1581,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPDeferAcceptOption: + e.mu.Lock() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1798,6 +1814,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.TCPDeferAcceptOption: + e.mu.Lock() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -2149,9 +2171,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { +func (e *endpoint) startAcceptedLoop() { e.mu.Lock() - e.waiterQueue = waiterQueue e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2177,7 +2198,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } - return n, n.waiterQueue, nil } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 7eb613be5..c9ee5bf06 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }) + }, queue) if err != nil { return nil, err } // Start the protocol goroutine. - ep.startAcceptedLoop(queue) + ep.startAcceptedLoop() return ep, nil } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index df2fb1071..a12336d47 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6787,3 +6787,129 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { ), ) } + +func TestTCPDeferAccept(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give a bit of time for the socket to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestTCPDeferAcceptTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Sleep for a little of the tcpDeferAccept timeout. + time.Sleep(tcpDeferAccept + 100*time.Millisecond) + + // On timeout expiry we should get a SYN-ACK retransmission. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give sometime for the endpoint to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 2f9821555..3bf7081b9 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -828,6 +828,164 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { EXPECT_EQ(get, kUserTimeout); } +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now write some data to the socket. + int data = 0; + ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); + + // This should now cause the connection to complete and be delivered to the + // accept socket. + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Verify that the accepted socket returns the data written. + int get = -1; + ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0), + SyscallSucceedsWithValue(sizeof(get))); + + EXPECT_EQ(get, data); +} + +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R once issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT + // timeout is hit. + absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1)); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout + // is hit a SYN-ACK should be retransmitted by the listener as a last ditch + // attempt to complete the connection with or without data. + absl::SleepFor(absl::Seconds(2)); + + // Verify that we have a connection that can be accepted even though no + // data was written. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 33a5ac66c..525ccbd88 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1286,6 +1286,59 @@ TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { EXPECT_EQ(get, kTCPUserTimeout); } +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // -ve TCP_DEFER_ACCEPT is same as setting it to zero. + constexpr int kNeg = -1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // kTCPDeferAccept is in seconds. + // NOTE: linux translates seconds to # of retries and back from + // #of retries to seconds. Which means only certain values + // translate back exactly. That's why we use 3 here, a value of + // 5 will result in us getting back 7 instead of 5 in the + // getsockopt. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPDeferAccept); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 02997af5abd62d778fca4d01b047a6bdebab2090 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Fri, 31 Jan 2020 15:08:17 -0800 Subject: Fix method comment to match method name. PiperOrigin-RevId: 292624867 --- pkg/tcpip/transport/tcp/accept.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6101f2945..08afb7c17 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -271,8 +271,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n, nil } -// createEndpoint creates a new endpoint in connected state and then performs -// the TCP 3-way handshake. +// createEndpointAndPerformHandshake creates a new endpoint in connected state +// and then performs the TCP 3-way handshake. func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber -- cgit v1.2.3 From c37b196455e8b3816298e3eea98e4ee2dab8d368 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Mon, 24 Feb 2020 10:31:01 -0800 Subject: Add support for tearing down protocol dispatchers and TIME_WAIT endpoints. Protocol dispatchers were previously leaked. Bypassing TIME_WAIT is required to test this change. Also fix a race when a socket in SYN-RCVD is closed. This is also required to test this change. PiperOrigin-RevId: 296922548 --- pkg/tcpip/adapters/gonet/gonet_test.go | 63 ++++++++++++++++++++++++++-------- pkg/tcpip/network/arp/arp.go | 20 +++++++---- pkg/tcpip/network/ipv4/ipv4.go | 6 ++++ pkg/tcpip/network/ipv6/ipv6.go | 6 ++++ pkg/tcpip/stack/registration.go | 23 ++++++++++--- pkg/tcpip/stack/stack.go | 14 +++++++- pkg/tcpip/stack/stack_test.go | 6 ++++ pkg/tcpip/stack/transport_demuxer.go | 20 ----------- pkg/tcpip/stack/transport_test.go | 15 +++++++- pkg/tcpip/tcpip.go | 8 ++++- pkg/tcpip/transport/icmp/endpoint.go | 5 +++ pkg/tcpip/transport/icmp/protocol.go | 16 ++++++--- pkg/tcpip/transport/packet/endpoint.go | 5 +++ pkg/tcpip/transport/raw/endpoint.go | 5 +++ pkg/tcpip/transport/tcp/accept.go | 9 ++++- pkg/tcpip/transport/tcp/connect.go | 4 +-- pkg/tcpip/transport/tcp/dispatcher.go | 31 ++++++++++++++++- pkg/tcpip/transport/tcp/endpoint.go | 33 ++++++++++++++++-- pkg/tcpip/transport/tcp/protocol.go | 14 ++++++-- pkg/tcpip/transport/udp/endpoint.go | 5 +++ pkg/tcpip/transport/udp/protocol.go | 14 +++++--- 21 files changed, 256 insertions(+), 66 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ea0a0409a..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4da13c5df..e9fcc89a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi }, nil } -// LinkAddressProtocol implements stack.LinkAddressResolver. +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -// LinkAddressRequest implements stack.LinkAddressResolver. +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ RemoteLinkAddress: broadcastMAC, @@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) } -// ResolveStaticAddress implements stack.LinkAddressResolver. +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return broadcastMAC, true @@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo return tcpip.LinkAddress([]byte(nil)), false } -// SetOption implements NetworkProtocol. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements NetworkProtocol. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6597e6781..4f1742938 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 180a480fd..9aef5234b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d83adf0ec..f9fd8f18f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -74,10 +74,11 @@ type TransportEndpoint interface { // HandleControlPacket takes ownership of pkt. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) - // Close puts the endpoint in a closed state and frees all resources - // associated with it. This cleanup may happen asynchronously. Wait can - // be used to block on this asynchronous cleanup. - Close() + // Abort initiates an expedited endpoint teardown. It puts the endpoint + // in a closed state and frees all resources associated with it. This + // cleanup may happen asynchronously. Wait can be used to block on this + // asynchronous cleanup. + Abort() // Wait waits for any worker goroutines owned by the endpoint to stop. // @@ -160,6 +161,13 @@ type TransportProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // TransportDispatcher contains the methods used by the network stack to deliver @@ -293,6 +301,13 @@ type NetworkProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 900dd46c5..ebb6c5e3b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1446,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { // Endpoints created or modified during this call may not get closed. func (s *Stack) Close() { for _, e := range s.RegisteredEndpoints() { - e.Close() + e.Abort() + } + for _, p := range s.transportProtocols { + p.proto.Close() + } + for _, p := range s.networkProtocols { + p.Close() } } @@ -1464,6 +1470,12 @@ func (s *Stack) Wait() { for _, e := range s.CleanupEndpoints() { e.Wait() } + for _, p := range s.transportProtocols { + p.proto.Wait() + } + for _, p := range s.networkProtocols { + p.Wait() + } s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 18016e7db..edf6bec52 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -235,6 +235,12 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d686e6eb8..778c0a4d6 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p ep.mu.RUnlock() // Don't use defer for performance reasons. } -// Close implements stack.TransportEndpoint.Close. -func (ep *multiPortEndpoint) Close() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Close() - } -} - -// Wait implements stack.TransportEndpoint.Wait. -func (ep *multiPortEndpoint) Wait() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Wait() - } -} - // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 869c69a6d..5d1da2f8b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } +func (f *fakeTransportEndpoint) Abort() { + f.Close() +} + func (f *fakeTransportEndpoint) Close() { f.route.Release() } @@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +// Abort implements TransportProtocol.Abort. +func (*fakeTransportProtocol) Abort() {} + +// Close implements tcpip.Endpoint.Close. +func (*fakeTransportProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeTransportProtocol) Wait() {} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ce5527391..3dc5d87d6 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -341,9 +341,15 @@ type ControlMessages struct { // networking stack. type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources - // associated with it. + // associated with it. Close initiates the teardown process, the + // Endpoint may not be fully closed when Close returns. Close() + // Abort initiates an expedited endpoint teardown. As compared to + // Close, Abort prioritizes closing the Endpoint quickly over cleanly. + // Abort is best effort; implementing Abort with Close is acceptable. + Abort() + // Read reads data from the endpoint and optionally returns the sender. // // This method does not block if there is no data pending. It will also diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 42afb3f5b..426da1ee6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9ce500e80..113d92901 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index fc5bc69fa..5722815e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb return ep, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (ep *endpoint) Close() { ep.mu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ee9c4c58b..2ef5fac76 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (e *endpoint) Close() { e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 08afb7c17..13e383ffc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + if l.listenEP != nil { l.removePendingEndpoint(ep) } @@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Unlock() // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5c5397823..7730e6445 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { + if n¬ifyReset != 0 || n¬ifyAbort != 0 { return tcpip.ErrConnectionAborted } @@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 { + if n¬ifyClose != 0 || n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index e18012ac0..d792b07d6 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -68,17 +68,28 @@ func (q *epQueue) empty() bool { type processor struct { epQ epQueue newEndpointWaker sleep.Waker + closeWaker sleep.Waker id int + wg sync.WaitGroup } func newProcessor(id int) *processor { p := &processor{ id: id, } + p.wg.Add(1) go p.handleSegments() return p } +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) wait() { + p.wg.Wait() +} + func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) @@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) handleSegments() { const newEndpointWaker = 1 + const closeWaker = 2 s := sleep.Sleeper{} s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + s.AddWaker(&p.closeWaker, closeWaker) defer s.Done() for { - s.Fetch(true) + id, ok := s.Fetch(true) + if ok && id == closeWaker { + p.wg.Done() + return + } for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { if ep.segmentQueue.empty() { continue @@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher { } } +func (d *dispatcher) close() { + for _, p := range d.processors { + p.close() + } +} + +func (d *dispatcher) wait() { + for _, p := range d.processors { + p.wait() + } +} + func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f2be0e651..f1ad19dac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,8 @@ const ( notifyDrain notifyReset notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + if e.EndpointState().connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. @@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() { // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { + switch e.EndpointState() { + // Sockets in StateSynRecv state(passive connections) are closed when + // the handshake fails or if the listening socket is closed while + // handshake was in progress. In such cases the handshake goroutine + // is already gone by the time Close is called and we need to cleanup + // here. + case StateInitial, StateBound, StateSynRecv: e.cleanupLocked() - } else { + e.setEndpointState(StateClose) + case StateError, StateClose: + // do nothing. + default: e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 958c06fa7..73098d904 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -194,7 +194,7 @@ func replyWithReset(s *segment) { sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: @@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index eff7f3600..1c6a600b8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -186,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 259c3072a..8df089d22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} -- cgit v1.2.3 From c6bdc6b05b4abfcb3c677013496c9a6b1d1365dd Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 27 Feb 2020 14:14:34 -0800 Subject: Fix a race in TCP endpoint teardown and teardown the stack in tcp_test. Call stack.Close on stacks when we are done with them in tcp_test. This avoids leaking resources and reduces the test's flakiness when race/gotsan is enabled. It also provides test coverage for the race also fixed in this change, which can be reliably triggered with the stack.Close change (and without the other changes) when race/gotsan is enabled. The race was possible when calling Abort (via stack.Close) on an endpoint processing a SYN segment as part of a passive connect. Updates #1564 PiperOrigin-RevId: 297685432 --- pkg/tcpip/transport/tcp/accept.go | 1 + pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 16 +++++++++++++++- pkg/tcpip/transport/tcp/testing/context/context.go | 1 + 4 files changed, 18 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 13e383ffc..85049e54e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -236,6 +236,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) n.amss = mssForRoute(&n.route) + n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 7730e6445..cd247f3e1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -577,7 +577,7 @@ func (h *handshake) execute() *tcpip.Error { case wakerForNotification: n := h.ep.fetchNotifications() - if n¬ifyClose != 0 { + if (n¬ifyClose)|(n¬ifyAbort) != 0 { return tcpip.ErrAborted } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f1ad19dac..9e72730bd 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -798,7 +798,21 @@ func (e *endpoint) Abort() { // If the endpoint disconnected after the check, nothing needs to be // done, so sending a notification which will potentially be ignored is // fine. - if e.EndpointState().connected() { + // + // If the endpoint connecting finishes after the check, the endpoint + // is either in a connected state (where we would notifyAbort anyway), + // SYN-RECV (where we would also notifyAbort anyway), or in an error + // state where nothing is required and the notification can be safely + // ignored. + // + // Endpoints where a Close during connecting or SYN-RECV state would be + // problematic are set to state connecting before being registered (and + // thus possible to be Aborted). They are never available in initial + // state. + // + // Endpoints transitioning from initial to connecting state may be + // safely either closed or sent notifyAbort. + if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() { e.notifyProtocolGoroutine(notifyAbort) return } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 1e9a0dea3..8cea20fb5 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -204,6 +204,7 @@ func (c *Context) Cleanup() { if c.EP != nil { c.EP.Close() } + c.Stack().Close() } // Stack returns a reference to the stack in the Context. -- cgit v1.2.3 From e9e399c25d4fcad2adfe92d73b192b9784774964 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 19 Mar 2020 07:18:47 -0700 Subject: Remove workMu from tcpip.Endpoint. workMu is removed and e.mu is now a mutex that supports TryLock. The packet processing path tries to lock the mutex and if its locked it will just queue the packet and move on. The endpoint.UnlockUser() will process any backlog of packets before unlocking the socket. This simplifies the locking inside tcp endpoints a lot. Further the endpoint.LockUser() implements spinning as long as the lock is not held by another syscall goroutine. This ensures low latency as not spinning leads to the task thread being put to sleep if the lock is held by the packet dispatch path. This is suboptimal as the lower layer rarely holds the lock for long so implementing spinning here helps. If the lock is held by another task goroutine then we just proceed to call LockUser() and the task could be put to sleep. The protocol goroutines themselves just call e.mu.Lock() and block if the lock is currently not available. Updates #231, #357 PiperOrigin-RevId: 301808349 --- pkg/sentry/kernel/epoll/epoll.go | 2 + pkg/sentry/socket/netstack/netstack.go | 24 +- pkg/tcpip/transport/tcp/accept.go | 90 +++--- pkg/tcpip/transport/tcp/connect.go | 78 ++--- pkg/tcpip/transport/tcp/dispatcher.go | 8 +- pkg/tcpip/transport/tcp/endpoint.go | 495 ++++++++++++++++-------------- pkg/tcpip/transport/tcp/endpoint_state.go | 5 +- pkg/tcpip/transport/tcp/protocol.go | 38 +-- pkg/tcpip/transport/tcp/rcv.go | 6 - pkg/tcpip/transport/tcp/segment_queue.go | 8 +- pkg/tcpip/transport/tcp/snd.go | 3 - pkg/tcpip/transport/tcp/tcp_test.go | 41 ++- 12 files changed, 424 insertions(+), 374 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 8bffb78fc..592650923 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -296,8 +296,10 @@ func (*readyCallback) Callback(w *waiter.Entry) { e.waitingList.Remove(entry) e.readyList.PushBack(entry) entry.curList = &e.readyList + e.listsMu.Unlock() e.Notify(waiter.EventIn) + return } e.listsMu.Unlock() diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 13a9a60b4..a2e1da02f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,6 +29,7 @@ import ( "io" "math" "reflect" + "sync/atomic" "syscall" "time" @@ -264,6 +265,12 @@ type SocketOperations struct { skType linux.SockType protocol int + // readViewHasData is 1 iff readView has data to be read, 0 otherwise. + // Must be accessed using atomic operations. It must only be written + // with readMu held but can be read without holding readMu. The latter + // is required to avoid deadlocks in epoll Readiness checks. + readViewHasData uint32 + // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` // readView contains the remaining payload from the last packet. @@ -410,21 +417,24 @@ func (s *SocketOperations) isPacketBased() bool { // fetchReadView updates the readView field of the socket if it's currently // empty. It assumes that the socket is locked. +// +// Precondition: s.readMu must be held. func (s *SocketOperations) fetchReadView() *syserr.Error { if len(s.readView) > 0 { return nil } - s.readView = nil s.sender = tcpip.FullAddress{} v, cms, err := s.Endpoint.Read(&s.sender) if err != nil { + atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms + atomic.StoreUint32(&s.readViewHasData, 1) return nil } @@ -623,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - s.readMu.Lock() - if len(s.readView) > 0 { + if atomic.LoadUint32(&s.readViewHasData) == 1 { r |= waiter.EventIn } - s.readMu.Unlock() } return r @@ -2334,6 +2342,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) @@ -2456,6 +2468,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 85049e54e..4d7602d54 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -221,7 +221,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu } // createConnectingEndpoint creates a new endpoint in a connecting state, with -// the connection parameters given by the arguments. +// the connection parameters given by the arguments. The endpoint is returned +// with n.mu held. func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto @@ -243,21 +244,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() - // Now inherit any socket options that should be inherited from the - // listening endpoint. - // In case of Forwarder listenEP will be nil and hence this check. - if l.listenEP != nil { - l.listenEP.propagateInheritableOptions(n) - } - - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { - n.Close() - return nil, err - } - - n.isRegistered = true - // Create sender and receiver. // // The receiver at least temporarily has a zero receive window scale, @@ -269,11 +255,27 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + n.mu.Lock() + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + return nil, err + } + + n.isRegistered = true + return n, nil } // createEndpointAndPerformHandshake creates a new endpoint in connected state // and then performs the TCP 3-way handshake. +// +// The new endpoint is returned with e.mu held. func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber @@ -289,9 +291,25 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() + // Ensure we release any registrations done by the newly + // created endpoint. + ep.mu.Unlock() + ep.Close() + + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + l.listenEP.propagateInheritableOptionsLocked(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } @@ -299,6 +317,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { + ep.mu.Unlock() ep.Close() // Wake up any waiters. This is strictly not required normally // as a socket that was never accepted can't really have any @@ -312,9 +331,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head } return nil, err } - ep.mu.Lock() ep.isConnectNotified = true - ep.mu.Unlock() // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -366,12 +383,12 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } -// propagateInheritableOptions propagates any options set on the listening +// propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. -func (e *endpoint) propagateInheritableOptions(n *endpoint) { - e.mu.Lock() +// +// Precondition: e.mu and n.mu must be held. +func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout - e.mu.Unlock() } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -382,7 +399,11 @@ func (e *endpoint) propagateInheritableOptions(n *endpoint) { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer decSynRcvdCount() - defer e.decSynRcvdCount() + defer func() { + e.mu.Lock() + e.decSynRcvdCount() + e.mu.Unlock() + }() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) @@ -399,29 +420,21 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } func (e *endpoint) incSynRcvdCount() bool { - e.mu.Lock() - if e.synRcvdCount >= cap(e.acceptedChan) { - e.mu.Unlock() + if e.synRcvdCount >= (cap(e.acceptedChan)) { return false } e.synRcvdCount++ - e.mu.Unlock() return true } func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() e.synRcvdCount-- - e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { - e.mu.Lock() if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { - e.mu.Unlock() return true } - e.mu.Unlock() return false } @@ -559,6 +572,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + // clear the tsOffset for the newly created // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was @@ -593,14 +610,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only - e.mu.Unlock() ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections // to the endpoint. - e.mu.Lock() e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. @@ -622,7 +637,10 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { - switch index, _ := s.Fetch(true); index { + e.mu.Unlock() + index, _ := s.Fetch(true) + e.mu.Lock() + switch index { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { @@ -635,7 +653,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.decRef() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } case wakerForNewSegment: diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index be86af502..edb37a549 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -61,6 +61,9 @@ const ( ) // handshake holds the state used during a TCP 3-way handshake. +// +// NOTE: handshake.ep.mu is held during handshake processing. It is released if +// we are going to block and reacquired when we start processing an event. type handshake struct { ep *endpoint state handshakeState @@ -209,9 +212,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.deferAccept = deferAccept - h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) - h.ep.mu.Unlock() } // checkAck checks if the ACK number, if present, of a segment received during @@ -241,9 +242,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK // was acceptable then signal the user "error: connection reset", drop // the segment, enter CLOSED state, delete TCB, and return." - h.ep.mu.Lock() h.ep.workerCleanup = true - h.ep.mu.Unlock() // Although the RFC above calls out ECONNRESET, Linux actually returns // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused @@ -281,9 +280,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted - h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) - h.ep.mu.Unlock() h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil @@ -293,11 +290,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd - h.ep.mu.Lock() ttl := h.ep.ttl amss := h.ep.amss h.ep.setEndpointState(StateSynRecv) - h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, @@ -357,10 +352,6 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - h.ep.mu.RLock() - amss := h.ep.amss - h.ep.mu.RUnlock() - h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -368,7 +359,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, - MSS: amss, + MSS: h.ep.amss, } h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -399,15 +390,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { } h.state = handshakeCompleted - h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver // to process it again once main loop is started. if s.data.Size() > 0 { s.incRef() h.ep.enqueueSegment(s) } - h.ep.mu.Unlock() return nil } @@ -493,7 +483,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } if n¬ifyDrain != 0 { close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() } } @@ -535,7 +527,6 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.mu.Lock() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ @@ -546,7 +537,6 @@ func (h *handshake) execute() *tcpip.Error { SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } - h.ep.mu.Unlock() // Execute is also called in a listen context so we want to make sure we // only send the TS/SACK option when we received the TS/SACK in the @@ -563,7 +553,11 @@ func (h *handshake) execute() *tcpip.Error { h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) for h.state != handshakeCompleted { - switch index, _ := s.Fetch(true); index { + h.ep.mu.Unlock() + index, _ := s.Fetch(true) + h.ep.mu.Lock() + switch index { + case wakerForResend: timeOut *= 2 if timeOut > MaxRTO { @@ -600,7 +594,9 @@ func (h *handshake) execute() *tcpip.Error { } } close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() } case wakerForNewSegment: @@ -1016,7 +1012,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - e.mu.Lock() switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set @@ -1040,11 +1035,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { case StateCloseWait: e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted - e.mu.Unlock() e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: - e.mu.Unlock() // RFC 793, page 37 states that "in all states // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So @@ -1157,9 +1150,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - e.mu.RLock() state := e.state - e.mu.RUnlock() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1182,9 +1173,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { - e.mu.RLock() userTimeout := e.userTimeout - e.mu.RUnlock() e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { @@ -1248,6 +1237,7 @@ func (e *endpoint) disableKeepaliveTimer() { // goroutine and is responsible for sending segments and handling received // segments. func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { + e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1269,7 +1259,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.mu.Unlock() - e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1280,16 +1269,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // completion. initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) - e.mu.Lock() h.ep.setEndpointState(StateSynSent) - e.mu.Unlock() if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() - e.mu.Lock() e.setEndpointState(StateError) e.HardError = err @@ -1302,9 +1288,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - e.mu.Lock() drained := e.drainDone != nil - e.mu.Unlock() if drained { close(e.drainDone) <-e.undrain @@ -1330,10 +1314,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // This means the socket is being closed due // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. - e.mu.Lock() e.transitionToStateCloseLocked() e.workerCleanup = true - e.mu.Unlock() return nil }, }, @@ -1388,7 +1370,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n¬ifyClose != 0 && closeTimer == nil { - e.mu.Lock() if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. @@ -1397,7 +1378,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }) e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } - e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1417,7 +1397,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } } @@ -1460,7 +1442,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.rcvListMu.Unlock() - e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } @@ -1468,7 +1449,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. cleanupOnError := func(err *tcpip.Error) { - e.mu.Lock() e.workerCleanup = true if err != nil { e.resetConnectionLocked(err) @@ -1480,16 +1460,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ loop: for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() - e.workMu.Unlock() v, _ := s.Fetch(true) - e.workMu.Lock() + e.mu.Lock() - // We need to double check here because the notification maybe + // We need to double check here because the notification may be // stale by the time we got around to processing it. - // - // NOTE: since we now hold the workMu the processors cannot - // change the state of the endpoint so it's safe to proceed - // after this check. switch e.EndpointState() { case StateError: // If the endpoint has already transitioned to an ERROR @@ -1502,21 +1477,17 @@ loop: case StateTimeWait: fallthrough case StateClose: - e.mu.Lock() break loop default: if err := funcs[v].f(); err != nil { cleanupOnError(err) return nil } - e.mu.Lock() } } - state := e.EndpointState() - e.mu.Unlock() var reuseTW func() - if state == StateTimeWait { + if e.EndpointState() == StateTimeWait { // Disable close timer as we now entering real TIME_WAIT. if closeTimer != nil { closeTimer.Stop() @@ -1526,14 +1497,11 @@ loop: s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.mu.Lock() e.workerCleanup = true - e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. - e.mu.Lock() if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1649,9 +1617,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) { defer timeWaitTimer.Stop() for { - e.workMu.Unlock() + e.mu.Unlock() v, _ := s.Fetch(true) - e.workMu.Lock() + e.mu.Lock() switch v { case newSegment: extendTimeWait, reuseTW := e.handleTimeWaitSegments() @@ -1674,7 +1642,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) { e.handleTimeWaitSegments() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() return nil } case timeWaitDone: diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index d792b07d6..90ac956a9 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -128,7 +128,7 @@ func (p *processor) handleSegments() { continue } - if !ep.workMu.TryLock() { + if !ep.mu.TryLock() { ep.newSegmentWaker.Assert() continue } @@ -138,12 +138,10 @@ func (p *processor) handleSegments() { if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { // Send any active resets if required. if err != nil { - ep.mu.Lock() ep.resetConnectionLocked(err) - ep.mu.Unlock() } ep.notifyProtocolGoroutine(notifyTickleWorker) - ep.workMu.Unlock() + ep.mu.Unlock() continue } @@ -151,7 +149,7 @@ func (p *processor) handleSegments() { p.epQ.enqueue(ep) } - ep.workMu.Unlock() + ep.mu.Unlock() } } } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 5187a5e25..eb8a9d73e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "fmt" "math" + "runtime" "strings" "sync/atomic" "time" @@ -33,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tmutex" "gvisor.dev/gvisor/pkg/waiter" ) @@ -283,6 +283,37 @@ func (*EndpointInfo) IsEndpointInfo() {} // synchronized. The protocol implementation, however, runs in a single // goroutine. // +// Each endpoint has a few mutexes: +// +// e.mu -> Primary mutex for an endpoint must be held for all operations except +// in e.Readiness where acquiring it will result in a deadlock in epoll +// implementation. +// +// The following three mutexes can be acquired independent of e.mu but if +// acquired with e.mu then e.mu must be acquired first. +// +// e.rcvListMu -> Protects the rcvList and associated fields. +// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.lastErrorMu -> Protects the lastError field. +// +// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different +// based on the context in which the lock is acquired. In the syscall context +// e.LockUser/e.UnlockUser should be used and when doing background processing +// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below +// in brief. +// +// The reason for this locking behaviour is to avoid wakeups to handle packets. +// In cases where the endpoint is already locked the background processor can +// queue the packet up and go its merry way and the lock owner will eventually +// process the backlog when releasing the lock. Similarly when acquiring the +// lock from say a syscall goroutine we can implement a bit of spinning if we +// know that the lock is not held by another syscall goroutine. Background +// processors should never hold the lock for long and we can avoid an expensive +// sleep/wakeup by spinning for a shortwhile. +// +// For more details please see the detailed documentation on +// e.LockUser/e.UnlockUser methods. +// // +stateify savable type endpoint struct { EndpointInfo @@ -299,12 +330,6 @@ type endpoint struct { // Precondition: epQueue.mu must be held to read/write this field.. pendingProcessing bool `state:"nosave"` - // workMu is used to arbitrate which goroutine may perform protocol - // work. Only the main protocol goroutine is expected to call Lock() on - // it, but other goroutines (e.g., send) may call TryLock() to eagerly - // perform work without having to wait for the main one to wake up. - workMu tmutex.Mutex `state:"nosave"` - // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` @@ -330,15 +355,11 @@ type endpoint struct { rcvBufSize int rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams - // zeroWindow indicates that the window was closed due to receive buffer - // space being filled up. This is set by the worker goroutine before - // moving a segment to the rcvList. This setting is cleared by the - // endpoint when a Read() call reads enough data for the new window to - // be non-zero. - zeroWindow bool - // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` + // mu protects all endpoint fields unless documented otherwise. mu must + // be acquired before interacting with the endpoint fields. + mu sync.Mutex `state:"nosave"` + ownedByUser uint32 // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` @@ -583,14 +604,93 @@ func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { return maxMSS } +// LockUser tries to lock e.mu and if it fails it will check if the lock is held +// by another syscall goroutine. If yes, then it will goto sleep waiting for the +// lock to be released, if not then it will spin till it acquires the lock or +// another syscall goroutine acquires it in which case it will goto sleep as +// described above. +// +// The assumption behind spinning here being that background packet processing +// should not be holding the lock for long and spinning reduces latency as we +// avoid an expensive sleep/wakeup of of the syscall goroutine). +func (e *endpoint) LockUser() { + for { + // Try first if the sock is locked then check if it's owned + // by another user goroutine if not then we spin, otherwise + // we just goto sleep on the Lock() and wait. + if !e.mu.TryLock() { + // If socket is owned by the user then just goto sleep + // as the lock could be held for a reasonably long time. + if atomic.LoadUint32(&e.ownedByUser) == 1 { + e.mu.Lock() + atomic.StoreUint32(&e.ownedByUser, 1) + return + } + // Spin but yield the processor since the lower half + // should yield the lock soon. + runtime.Gosched() + continue + } + atomic.StoreUint32(&e.ownedByUser, 1) + return + } +} + +// UnlockUser will check if there are any segments already queued for processing +// and process any such segments before unlocking e.mu. This is required because +// we when packets arrive and endpoint lock is already held then such packets +// are queued up to be processed. If the lock is held by the endpoint goroutine +// then it will process these packets but if the lock is instead held by the +// syscall goroutine then we can have the syscall goroutine process the backlog +// before unlocking. +// +// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the +// endpoint. It's also required eventually when we get rid of the endpoint +// protocol goroutine altogether. +// +// Precondition: e.LockUser() must have been called before calling e.UnlockUser() +func (e *endpoint) UnlockUser() { + // Lock segment queue before checking so that we avoid a race where + // segments can be queued between the time we check if queue is empty + // and actually unlock the endpoint mutex. + for { + e.segmentQueue.mu.Lock() + if e.segmentQueue.emptyLocked() { + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + e.segmentQueue.mu.Unlock() + return + } + e.segmentQueue.mu.Unlock() + + switch e.EndpointState() { + case StateEstablished: + if err := e.handleSegments(true /* fastPath */); err != nil { + e.notifyProtocolGoroutine(notifyTickleWorker) + } + default: + // Since we are waking the endpoint goroutine here just unlock + // and let it process the queued segments. + e.newSegmentWaker.Assert() + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + return + } + } +} + // StopWork halts packet processing. Only to be used in tests. func (e *endpoint) StopWork() { - e.workMu.Lock() + e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. func (e *endpoint) ResumeWork() { - e.workMu.Unlock() + e.mu.Unlock() } // setEndpointState updates the state of the endpoint to state atomically. This @@ -709,8 +809,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - e.workMu.Lock() e.tsOffset = timeStampOffset() return e @@ -721,9 +819,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) - e.mu.RLock() - defer e.mu.RUnlock() - switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -823,20 +918,22 @@ func (e *endpoint) Abort() { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { - e.mu.Lock() - closed := e.closed - e.closed = true - e.mu.Unlock() - if closed { + e.LockUser() + defer e.UnlockUser() + if e.closed { return } // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. - e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - - e.mu.Lock() + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdownLocked() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdownLocked() { // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail @@ -853,6 +950,8 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. switch e.EndpointState() { @@ -873,8 +972,6 @@ func (e *endpoint) Close() { // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) } - - e.mu.Unlock() } // closePendingAcceptableConnections closes all connections that have completed @@ -909,7 +1006,6 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { - // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { @@ -954,18 +1050,18 @@ func (e *endpoint) initialReceiveWindow() int { // ModerateRecvBuf adjusts the receive buffer and the advertised window // based on the number of bytes copied to user space. func (e *endpoint) ModerateRecvBuf(copied int) { - e.mu.RLock() + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() - e.mu.RUnlock() return } now := time.Now() if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { e.rcvAutoParams.copied += copied e.rcvListMu.Unlock() - e.mu.RUnlock() return } prevRTTCopied := e.rcvAutoParams.copied + copied @@ -1021,7 +1117,6 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvAutoParams.measureTime = now e.rcvAutoParams.copied = 0 e.rcvListMu.Unlock() - e.mu.RUnlock() } // IPTables implements tcpip.Endpoint.IPTables. @@ -1031,7 +1126,7 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() + e.LockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the @@ -1041,7 +1136,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.mu.RUnlock() + e.UnlockUser() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } @@ -1051,7 +1146,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() @@ -1124,13 +1219,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. - e.mu.RLock() + e.LockUser() e.sndBufMu.Lock() avail, err := e.isEndpointWritableLocked() if err != nil { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -1142,113 +1237,68 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // are copying data in. if !opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } // Fetch data. v, perr := p.Payload(avail) if perr != nil || len(v) == 0 { - if opts.Atomic { // See above. + // Note that perr may be nil if len(v) == 0. + if opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } - // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - if opts.Atomic { + queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. s := newSegmentFromView(&e.route, e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - } - if e.workMu.TryLock() { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() - - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - } // Do the work inline. e.handleWrite() - e.workMu.Unlock() - } else { - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + e.UnlockUser() + return int64(len(v)), nil, nil + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } + if opts.Atomic { + // Locks released in queueAndSend() + return queueAndSend() + } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + e.LockUser() + e.sndBufMu.Lock() + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - } - // Let the protocol goroutine do the work. - e.sndWaker.Assert() + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } - return int64(len(v)), nil, nil + // Locks released in queueAndSend() + return queueAndSend() } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. @@ -1339,6 +1389,9 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + switch opt { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. @@ -1346,9 +1399,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.mu.Lock() - defer e.mu.Unlock() - // We only allow this to be set when we're in the initial state. if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState @@ -1379,7 +1429,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { mask := uint32(notifyReceiveWindowChanged) - e.mu.RLock() + e.LockUser() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1409,8 +1459,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } + e.rcvListMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.notifyProtocolGoroutine(mask) return nil @@ -1466,15 +1517,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.ReuseAddressOption: - e.mu.Lock() + e.LockUser() e.reuseAddr = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.ReusePortOption: - e.mu.Lock() + e.LockUser() e.reusePort = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.BindToDeviceOption: @@ -1482,9 +1533,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } - e.mu.Lock() + e.LockUser() e.bindToDevice = id - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.QuickAckOption: @@ -1500,16 +1551,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { return tcpip.ErrInvalidOptionValue } - e.mu.Lock() + e.LockUser() e.userMSS = uint16(userMSS) - e.mu.Unlock() + e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) return nil case tcpip.TTLOption: - e.mu.Lock() + e.LockUser() e.ttl = uint8(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.KeepaliveEnabledOption: @@ -1541,15 +1592,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.TCPUserTimeoutOption: - e.mu.Lock() + e.LockUser() e.userTimeout = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.BroadcastOption: - e.mu.Lock() + e.LockUser() e.broadcast = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.CongestionControlOption: @@ -1563,22 +1614,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { availCC := strings.Split(string(avail), " ") for _, cc := range availCC { if v == tcpip.CongestionControlOption(cc) { - // Acquire the work mutex as we may need to - // reinitialize the congestion control state. - e.mu.Lock() + e.LockUser() state := e.EndpointState() e.cc = v - e.mu.Unlock() switch state { case StateEstablished: - e.workMu.Lock() - e.mu.Lock() if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } - e.mu.Unlock() - e.workMu.Unlock() } + e.UnlockUser() return nil } } @@ -1588,23 +1633,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrNoSuchFile case tcpip.IPv4TOSOption: - e.mu.Lock() + e.LockUser() // TODO(gvisor.dev/issue/995): ECN is not currently supported, // ignore the bits for now. e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.IPv6TrafficClassOption: - e.mu.Lock() + e.LockUser() // TODO(gvisor.dev/issue/995): ECN is not currently supported, // ignore the bits for now. e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.TCPLingerTimeoutOption: - e.mu.Lock() + e.LockUser() if v < 0 { // Same as effectively disabling TCPLinger timeout. v = 0 @@ -1622,16 +1667,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { v = stkTCPLingerTimeout } e.tcpLingerTimeout = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.TCPDeferAcceptOption: - e.mu.Lock() + e.LockUser() if time.Duration(v) > MaxRTO { v = tcpip.TCPDeferAcceptOption(MaxRTO) } e.deferAccept = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil default: @@ -1641,8 +1686,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // readyReceiveSize returns the number of bytes ready to be received. func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint cannot be in listen state. if e.EndpointState() == StateListen { @@ -1664,9 +1709,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrUnknownProtocolOption } - e.mu.Lock() + e.LockUser() v := e.v6only - e.mu.Unlock() + e.UnlockUser() return v, nil } @@ -1730,9 +1775,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.ReuseAddressOption: - e.mu.RLock() + e.LockUser() v := e.reuseAddr - e.mu.RUnlock() + e.UnlockUser() *o = 0 if v { @@ -1741,9 +1786,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.ReusePortOption: - e.mu.RLock() + e.LockUser() v := e.reusePort - e.mu.RUnlock() + e.UnlockUser() *o = 0 if v { @@ -1752,9 +1797,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.BindToDeviceOption: - e.mu.RLock() + e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.QuickAckOption: @@ -1765,16 +1810,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.TTLOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} - e.mu.RLock() + e.LockUser() snd := e.snd - e.mu.RUnlock() + e.UnlockUser() if snd != nil { snd.rtt.Lock() o.RTT = snd.rtt.srtt @@ -1813,9 +1858,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.TCPUserTimeoutOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPUserTimeoutOption(e.userTimeout) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.OutOfBandInlineOption: @@ -1824,9 +1869,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.BroadcastOption: - e.mu.Lock() + e.LockUser() v := e.broadcast - e.mu.Unlock() + e.UnlockUser() *o = 0 if v { @@ -1835,33 +1880,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.CongestionControlOption: - e.mu.Lock() + e.LockUser() *o = e.cc - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.IPv4TOSOption: - e.mu.RLock() + e.LockUser() *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() + e.LockUser() *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.TCPLingerTimeoutOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.TCPDeferAcceptOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPDeferAcceptOption(e.deferAccept) - e.mu.Unlock() + e.UnlockUser() return nil default: @@ -1901,8 +1946,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // yet accepted by the app, they are restored without running the main goroutine // here. func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() connectingAddr := addr.Addr @@ -2071,9 +2116,13 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - e.mu.Lock() + e.LockUser() + defer e.UnlockUser() + return e.shutdownLocked(flags) +} + +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags - finQueued := false switch { case e.EndpointState().connected(): // Close for read. @@ -2087,24 +2136,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.mu.Unlock() - // Try to send an active reset immediately if the - // work mutex is available. - if e.workMu.TryLock() { - e.mu.Lock() - // We need to double check here to make - // sure worker has not transitioned the - // endpoint out of a connected state - // before trying to send a reset. - if e.EndpointState().connected() { - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.notifyProtocolGoroutine(notifyTickleWorker) - } - e.mu.Unlock() - e.workMu.Unlock() - } else { - e.notifyProtocolGoroutine(notifyReset) - } + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + // Wake up worker to terminate loop. + e.notifyProtocolGoroutine(notifyTickleWorker) return nil } } @@ -2116,42 +2150,32 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Already closed. e.sndBufMu.Unlock() if e.EndpointState() == StateTimeWait { - e.mu.Unlock() return tcpip.ErrNotConnected } - break + return nil } // Queue fin segment. s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() + e.handleClose() } + return nil case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } + return nil + default: - e.mu.Unlock() return tcpip.ErrNotConnected } - e.mu.Unlock() - if finQueued { - if e.workMu.TryLock() { - e.handleClose() - e.workMu.Unlock() - } else { - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() - } - } - return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept @@ -2166,8 +2190,8 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error { } func (e *endpoint) listen(backlog int) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -2229,7 +2253,6 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop() { - e.mu.Lock() e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2240,8 +2263,8 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // Endpoint must be in listen state before it can accept connections. if e.EndpointState() != StateListen { @@ -2260,8 +2283,8 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { // Bind binds the endpoint to a specific local port and optionally address. func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() return e.bindLocked(addr) } @@ -2339,8 +2362,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // GetLocalAddress returns the address to which the endpoint is bound. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() return tcpip.FullAddress{ Addr: e.ID.LocalAddress, @@ -2351,8 +2374,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress returns the address to which the endpoint is connected. func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected @@ -2419,7 +2442,6 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { - e.mu.RLock() e.rcvListMu.Lock() if s != nil { s.incRef() @@ -2434,7 +2456,6 @@ func (e *endpoint) readyToRead(s *segment) { e.rcvClosed = true } e.rcvListMu.Unlock() - e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventIn) } @@ -2578,9 +2599,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SegTime = time.Now() // Copy EndpointID. - e.mu.Lock() s.ID = stack.TCPEndpointID(e.ID) - e.mu.Unlock() // Copy endpoint rcv state. e.rcvListMu.Lock() @@ -2710,10 +2729,10 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() + e.LockUser() // Make a copy of the endpoint info. ret := e.EndpointInfo - e.mu.RUnlock() + e.UnlockUser() return &ret } @@ -2728,9 +2747,9 @@ func (e *endpoint) Wait() { e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) defer e.waiterQueue.EventUnregister(&waitEntry) for { - e.mu.Lock() + e.LockUser() running := e.workerRunning - e.mu.Unlock() + e.UnlockUser() if !running { break } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4a46f0ec5..9175de441 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -162,8 +162,8 @@ func (e *endpoint) loadState(state EndpointState) { connectingLoading.Add(1) } // Directly update the state here rather than using e.setEndpointState - // as the endpoint is still being loaded and the stack reference to increment - // metrics is not yet initialized. + // as the endpoint is still being loaded and the stack reference is not + // yet initialized. atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } @@ -180,7 +180,6 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() state := e.origEndpointState switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 73098d904..b0f918bb4 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -95,7 +95,7 @@ const ( ) type protocol struct { - mu sync.Mutex + mu sync.RWMutex sackEnabled bool delayEnabled bool sendBufferSize SendBufferSizeOption @@ -273,57 +273,57 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: - p.mu.Lock() + p.mu.RLock() *v = SACKEnabled(p.sackEnabled) - p.mu.Unlock() + p.mu.RUnlock() return nil case *DelayEnabled: - p.mu.Lock() + p.mu.RLock() *v = DelayEnabled(p.delayEnabled) - p.mu.Unlock() + p.mu.RUnlock() return nil case *SendBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.sendBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *ReceiveBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.recvBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.CongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.CongestionControlOption(p.congestionControl) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.AvailableCongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.ModerateReceiveBufferOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.TCPLingerTimeoutOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.TCPTimeWaitTimeoutOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) - p.mu.Unlock() + p.mu.RUnlock() return nil default: diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index d80aff1b6..caf8977b3 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -168,7 +168,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. - r.ep.mu.Lock() switch r.ep.EndpointState() { case StateEstablished: r.ep.setEndpointState(StateCloseWait) @@ -183,7 +182,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateFinWait2: r.ep.setEndpointState(StateTimeWait) } - r.ep.mu.Unlock() // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the @@ -208,7 +206,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { - r.ep.mu.Lock() switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -222,7 +219,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateLastAck: r.ep.transitionToStateCloseLocked() } - r.ep.mu.Unlock() } return true @@ -336,10 +332,8 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { - r.ep.mu.RLock() state := r.ep.EndpointState() closed := r.ep.closed - r.ep.mu.RUnlock() if state != StateEstablished { drop, err := r.handleRcvdSegmentClosing(s, state, closed) diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index bd20a7ee9..48a257137 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -28,10 +28,16 @@ type segmentQueue struct { used int } +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *segmentQueue) emptyLocked() bool { + return q.used == 0 +} + // empty determines if the queue is empty. func (q *segmentQueue) empty() bool { q.mu.Lock() - r := q.used == 0 + r := q.emptyLocked() q.mu.Unlock() return r diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 657c3146e..17fed4ec5 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -455,9 +455,7 @@ func (s *sender) retransmitTimerExpired() bool { // Give up if we've waited more than a minute since the last resend or // if a user time out is set and we have exceeded the user specified // timeout since the first retransmission. - s.ep.mu.RLock() uto := s.ep.userTimeout - s.ep.mu.RUnlock() if s.firstRetransmittedSegXmitTime.IsZero() { // We store the original xmitTime of the segment that we are @@ -713,7 +711,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.setEndpointState(StateFinWait1) } - } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 5b2b16afa..39d36d2ba 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2236,9 +2236,17 @@ func TestSegmentMerging(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - // Prevent the endpoint from processing packets. - test.stop(c.EP) + // Send 10 1 byte segments to fill up InitialWindow but don't + // ACK. That should prevent anymore packets from going out. + for i := 0; i < 10; i++ { + view := buffer.NewViewFromBytes([]byte{0}) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %v", i+1, err) + } + } + // Now send the segments that should get merged as the congestion + // window is full and we won't be able to send any more packets. var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) @@ -2248,8 +2256,29 @@ func TestSegmentMerging(t *testing.T) { } } - // Let the endpoint process the segments that we just sent. - test.resume(c.EP) + // Check that we get 10 packets of 1 byte each. + for i := 0; i < 10; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize+1), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+uint32(i)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. + RcvWnd: 30000, + }) // Check that data is received. b := c.GetPacket() @@ -2257,7 +2286,7 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+11), checker.AckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), @@ -2273,7 +2302,7 @@ func TestSegmentMerging(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), + AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), RcvWnd: 30000, }) }) -- cgit v1.2.3 From fd27a917ef068a4e17cddbe3f671b59f52e6e030 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 19 Mar 2020 09:41:56 -0700 Subject: Address comments on workMu removal change. Updates #231, #357 PiperOrigin-RevId: 301833669 --- pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/tcp_test.go | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 4d7602d54..3f80995f3 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -420,7 +420,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } func (e *endpoint) incSynRcvdCount() bool { - if e.synRcvdCount >= (cap(e.acceptedChan)) { + if e.synRcvdCount >= cap(e.acceptedChan) { return false } e.synRcvdCount++ diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 39d36d2ba..ce3df7478 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2236,12 +2236,13 @@ func TestSegmentMerging(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - // Send 10 1 byte segments to fill up InitialWindow but don't - // ACK. That should prevent anymore packets from going out. - for i := 0; i < 10; i++ { + // Send tcp.InitialCwnd number of segments to fill up + // InitialWindow but don't ACK. That should prevent + // anymore packets from going out. + for i := 0; i < tcp.InitialCwnd; i++ { view := buffer.NewViewFromBytes([]byte{0}) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2256,8 +2257,8 @@ func TestSegmentMerging(t *testing.T) { } } - // Check that we get 10 packets of 1 byte each. - for i := 0; i < 10; i++ { + // Check that we get tcp.InitialCwnd packets. + for i := 0; i < tcp.InitialCwnd; i++ { b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(header.TCPMinimumSize+1), -- cgit v1.2.3 From c8eeedcc1d6b1ee25532ae630a7efd7aa4656bdc Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 24 Mar 2020 15:33:16 -0700 Subject: Add support for setting TCP segment hash. This allows the link layer endpoints to consistenly hash a TCP segment to a single underlying queue in case a link layer endpoint does support multiple underlying queues. Updates #231 PiperOrigin-RevId: 302760664 --- pkg/tcpip/link/fdbased/endpoint.go | 10 ++- pkg/tcpip/link/fdbased/endpoint_test.go | 76 +++++++++++++----- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/packet_buffer.go | 5 ++ pkg/tcpip/stack/packet_buffer_state.go | 1 + pkg/tcpip/stack/rand.go | 40 +++++++++ pkg/tcpip/stack/stack.go | 44 +++++++++- pkg/tcpip/transport/tcp/accept.go | 20 ++++- pkg/tcpip/transport/tcp/connect.go | 138 +++++++++++++++++++++++--------- pkg/tcpip/transport/tcp/endpoint.go | 5 ++ pkg/tcpip/transport/tcp/protocol.go | 10 ++- 11 files changed, 280 insertions(+), 70 deletions(-) create mode 100644 pkg/tcpip/stack/rand.go (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 235e647ff..3b3b6909b 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -405,6 +405,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne eth.Encode(ethHdr) } + fd := e.fds[pkt.Hash%uint32(len(e.fds))] if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if gso != nil { @@ -428,14 +429,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) - return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) } if pkt.Data.Size() == 0 { - return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View()) + return rawfile.NonBlockingWrite(fd, pkt.Header.View()) } - return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil) + return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) } // WritePackets writes outbound packets to the file descriptor. If it is not @@ -551,7 +552,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac packets := 0 for packets < n { - sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs) + fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))] + sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) if err != nil { return packets, err } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index c7dbbbc6b..3bfb15a8e 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -49,36 +49,42 @@ type packetInfo struct { } type context struct { - t *testing.T - fds [2]int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} + t *testing.T + readFDs []int + writeFDs []int + ep stack.LinkEndpoint + ch chan packetInfo + done chan struct{} } func newContext(t *testing.T, opt *Options) *context { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + t.Fatalf("Socketpair failed: %v", err) + } + secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) if err != nil { t.Fatalf("Socketpair failed: %v", err) } - done := make(chan struct{}, 1) + done := make(chan struct{}, 2) opt.ClosedFunc = func(*tcpip.Error) { done <- struct{}{} } - opt.FDs = []int{fds[1]} + opt.FDs = []int{firstFDPair[1], secondFDPair[1]} ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } c := &context{ - t: t, - fds: fds, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, + t: t, + readFDs: []int{firstFDPair[0], secondFDPair[0]}, + writeFDs: opt.FDs, + ep: ep, + ch: make(chan packetInfo, 100), + done: done, } ep.Attach(c) @@ -87,9 +93,14 @@ func newContext(t *testing.T, opt *Options) *context { } func (c *context) cleanup() { - syscall.Close(c.fds[0]) + for _, fd := range c.readFDs { + syscall.Close(fd) + } + <-c.done <-c.done - syscall.Close(c.fds[1]) + for _, fd := range c.writeFDs { + syscall.Close(fd) + } } func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { @@ -136,7 +147,7 @@ func TestAddress(t *testing.T) { } } -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { +func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) { c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) defer c.cleanup() @@ -171,13 +182,15 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), + Hash: hash, }); err != nil { t.Fatalf("WritePacket failed: %v", err) } - // Read from fd, then compare with what we wrote. + // Read from the corresponding FD, then compare with what we wrote. b = make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) + fd := c.readFDs[hash%uint32(len(c.readFDs))] + n, err := syscall.Read(fd, b) if err != nil { t.Fatalf("Read failed: %v", err) } @@ -238,7 +251,7 @@ func TestWritePacket(t *testing.T) { t.Run( fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), func(t *testing.T) { - testWritePacket(t, plen, eth, gso) + testWritePacket(t, plen, eth, gso, 0) }, ) } @@ -246,6 +259,27 @@ func TestWritePacket(t *testing.T) { } } +func TestHashedWritePacket(t *testing.T) { + lengths := []int{0, 100, 1000} + eths := []bool{true, false} + gsos := []uint32{0, 32768} + hashes := []uint32{0, 1} + for _, eth := range eths { + for _, plen := range lengths { + for _, gso := range gsos { + for _, hash := range hashes { + t.Run( + fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash), + func(t *testing.T) { + testWritePacket(t, plen, eth, gso, hash) + }, + ) + } + } + } + } +} + func TestPreserveSrcAddress(t *testing.T) { baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") @@ -270,7 +304,7 @@ func TestPreserveSrcAddress(t *testing.T) { // Read from the FD, then compare with what we wrote. b := make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) + n, err := syscall.Read(c.readFDs[0], b) if err != nil { t.Fatalf("Read failed: %v", err) } @@ -314,7 +348,7 @@ func TestDeliverPacket(t *testing.T) { } // Write packet via the file descriptor. - if _, err := syscall.Write(c.fds[0], all); err != nil { + if _, err := syscall.Write(c.readFDs[0], all); err != nil { t.Fatalf("Write failed: %v", err) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 7a43a1d4e..8d80e9cee 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -30,6 +30,7 @@ go_library( "nic.go", "packet_buffer.go", "packet_buffer_state.go", + "rand.go", "registration.go", "route.go", "stack.go", diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1850fa8c3..9505a4e92 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -10,6 +10,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package stack import "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,6 +55,10 @@ type PacketBuffer struct { LinkHeader buffer.View NetworkHeader buffer.View TransportHeader buffer.View + + // Hash is the transport layer hash of this packet. A value of zero + // indicates no valid hash has been set. + Hash uint32 } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go index 76602549e..0c6b7924c 100644 --- a/pkg/tcpip/stack/packet_buffer_state.go +++ b/pkg/tcpip/stack/packet_buffer_state.go @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package stack import "gvisor.dev/gvisor/pkg/tcpip/buffer" diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go new file mode 100644 index 000000000..421fb5c15 --- /dev/null +++ b/pkg/tcpip/stack/rand.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + mathrand "math/rand" + + "gvisor.dev/gvisor/pkg/sync" +) + +// lockedRandomSource provides a threadsafe rand.Source. +type lockedRandomSource struct { + mu sync.Mutex + src mathrand.Source +} + +func (r *lockedRandomSource) Int63() (n int64) { + r.mu.Lock() + n = r.src.Int63() + r.mu.Unlock() + return n +} + +func (r *lockedRandomSource) Seed(seed int64) { + r.mu.Lock() + r.src.Seed(seed) + r.mu.Unlock() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a9584d636..41398a1b6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,7 +20,9 @@ package stack import ( + "bytes" "encoding/binary" + mathrand "math/rand" "sync/atomic" "time" @@ -465,6 +467,10 @@ type Stack struct { // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue + + // randomGenerator is an injectable pseudo random generator that can be + // used when a random number is required. + randomGenerator *mathrand.Rand } // UniqueID is an abstract generator of unique identifiers. @@ -525,9 +531,16 @@ type Options struct { // this is non-nil. RawFactory RawFactory - // OpaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // RandSource is an optional source to use to generate random + // numbers. If omitted it defaults to a Source seeded by the data + // returned by rand.Read(). + // + // RandSource must be thread-safe. + RandSource mathrand.Source } // TransportEndpointInfo holds useful information about a transport endpoint @@ -623,6 +636,13 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + randSrc := opts.RandSource + if randSrc == nil { + // Source provided by mathrand.NewSource is not thread-safe so + // we wrap it in a simple thread-safe version. + randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + } + // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() @@ -645,6 +665,7 @@ func New(opts Options) *Stack { ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), } // Add specified network protocols. @@ -1818,6 +1839,12 @@ func (s *Stack) Seed() uint32 { return s.seed } +// Rand returns a reference to a pseudo random generator that can be used +// to generate random numbers as required. +func (s *Stack) Rand() *mathrand.Rand { + return s.randomGenerator +} + func generateRandUint32() uint32 { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { @@ -1825,3 +1852,16 @@ func generateRandUint32() uint32 { } return binary.LittleEndian.Uint32(b) } + +func generateRandInt64() int64 { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + panic(err) + } + buf := bytes.NewReader(b) + var v int64 + if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { + panic(err) + } + return v +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3f80995f3..b4c4c8ab1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -445,7 +445,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + e.sendTCP(&s.route, tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagRst, + seq: s.ackNumber, + ack: 0, + rcvWnd: 0, + }, buffer.VectorisedView{}, nil) return } @@ -493,7 +501,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TSEcr: opts.TSVal, MSS: mssForRoute(&s.route), } - e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + }, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 79552fc61..1d245c2c6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -308,7 +308,15 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if ttl == 0 { ttl = s.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, tcpFields{ + id: h.ep.ID, + ttl: ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) return nil } @@ -361,7 +369,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) return nil } @@ -550,7 +566,16 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + h.ep.sendSynTCP(&h.ep.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) for h.state != handshakeCompleted { h.ep.mu.Unlock() @@ -573,7 +598,15 @@ func (h *handshake) execute() *tcpip.Error { // the connection with another ACK or data (as ACKs are never // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) } case wakerForNotification: @@ -686,18 +719,33 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { - options := makeSynOptions(opts) +// tcpFields is a struct to carry different parameters required by the +// send*TCP variant functions below. +type tcpFields struct { + id stack.TransportEndpointID + ttl uint8 + tos uint8 + flags byte + seq seqnum.Value + ack seqnum.Value + rcvWnd seqnum.Size + opts []byte + txHash uint32 +} + +func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error { + tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. - if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { + if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { e.stats.SendErrors.SynSendToNetworkFailed.Increment() } - putOptions(options) + putOptions(tf.opts) return nil } -func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { + tf.txHash = e.txHash + if err := sendTCP(r, tf, data, gso); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() return err } @@ -705,8 +753,8 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { - optLen := len(opts) +func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { + optLen := len(tf.opts) hdr := &pkt.Header packetSize := pkt.DataSize off := pkt.DataOffset @@ -714,15 +762,15 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.Packet tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) pkt.TransportHeader = buffer.View(tcp) tcp.Encode(&header.TCPFields{ - SrcPort: id.LocalPort, - DstPort: id.RemotePort, - SeqNum: uint32(seq), - AckNum: uint32(ack), + SrcPort: tf.id.LocalPort, + DstPort: tf.id.RemotePort, + SeqNum: uint32(tf.seq), + AckNum: uint32(tf.ack), DataOffset: uint8(header.TCPMinimumSize + optLen), - Flags: flags, - WindowSize: uint16(rcvWnd), + Flags: tf.flags, + WindowSize: uint16(tf.rcvWnd), }) - copy(tcp[header.TCPMinimumSize:], opts) + copy(tcp[header.TCPMinimumSize:], tf.opts) length := uint16(hdr.UsedLength() + packetSize) xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) @@ -737,13 +785,12 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.Packet xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } - } -func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff } mss := int(gso.MSS) @@ -768,14 +815,15 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect pkts[i].DataOffset = off pkts[i].DataSize = packetSize pkts[i].Data = data - buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso) + pkts[i].Hash = tf.txHash + buildTCPHdr(r, tf, &pkts[i], gso) off += packetSize - seq = seq.Add(seqnum.Size(packetSize)) + tf.seq = tf.seq.Add(seqnum.Size(packetSize)) } - if ttl == 0 { - ttl = r.DefaultTTL() + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() } - sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) if err != nil { r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) } @@ -785,14 +833,14 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { - return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + return sendTCPBatch(r, tf, data, gso) } pkt := stack.PacketBuffer{ @@ -800,18 +848,19 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise DataOffset: 0, DataSize: data.Size(), Data: data, + Hash: tf.txHash, } - buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso) + buildTCPHdr(r, tf, &pkt, gso) - if ttl == 0 { - ttl = r.DefaultTTL() + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } r.Stats().TCP.SegmentsSent.Increment() - if (flags & header.TCPFlagRst) != 0 { + if (tf.flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } return nil @@ -863,7 +912,16 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, tcpFields{ + id: e.ID, + ttl: e.ttl, + tos: e.sendTOS, + flags: flags, + seq: seq, + ack: ack, + rcvWnd: rcvWnd, + opts: options, + }, data, e.gso) putOptions(options) return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 594efaa11..b6e571361 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -581,6 +581,10 @@ type endpoint struct { // endpoint and at this point the endpoint is only around // to complete the TCP shutdown. closed bool + + // txHash is the transport layer hash to be set on outbound packets + // emitted by this endpoint. + txHash uint32 } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -771,6 +775,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue count: 9, }, uniqueID: s.UniqueID(), + txHash: s.Rand().Uint32(), } var ss SendBufferSizeOption diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 57985b85d..1377107ca 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -191,7 +191,15 @@ func replyWithReset(s *segment) { flags |= header.TCPFlagAck ack = s.sequenceNumber.Add(s.logicalLen()) } - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, tcpFields{ + id: s.id, + ttl: s.route.DefaultTTL(), + tos: stack.DefaultTOS, + flags: flags, + seq: seq, + ack: ack, + rcvWnd: 0, + }, buffer.VectorisedView{}, nil /* gso */) } // SetOption implements stack.TransportProtocol.SetOption. -- cgit v1.2.3 From d04adebaab86ac30aca463b06528fc22430598ac Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 25 Mar 2020 10:54:19 -0700 Subject: Fix data-race in endpoint.Readiness PiperOrigin-RevId: 302924789 --- pkg/sync/aliases.go | 5 +++ pkg/tcpip/transport/tcp/accept.go | 46 ++++++++++++-------- pkg/tcpip/transport/tcp/endpoint.go | 70 ++++++++++++++++++++----------- pkg/tcpip/transport/tcp/endpoint_state.go | 3 ++ 4 files changed, 82 insertions(+), 42 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go index d2d7132fa..0d4316254 100644 --- a/pkg/sync/aliases.go +++ b/pkg/sync/aliases.go @@ -29,3 +29,8 @@ type ( // Map is an alias of sync.Map. Map = sync.Map ) + +// NewCond is a wrapper around sync.NewCond. +func NewCond(l Locker) *Cond { + return sync.NewCond(l) +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b4c4c8ab1..375ca21f6 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -365,21 +365,29 @@ func (l *listenContext) closeAllPendingEndpoints() { } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state, the new endpoint is closed -// instead. +// endpoint has transitioned out of the listen state (acceptedChan is nil), +// the new endpoint is closed instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.EndpointState() e.pendingAccepted.Add(1) - defer e.pendingAccepted.Done() - acceptedChan := e.acceptedChan e.mu.Unlock() + defer e.pendingAccepted.Done() - if state == StateListen { - acceptedChan <- n - e.waiterQueue.Notify(waiter.EventIn) - } else { - n.Close() + e.acceptMu.Lock() + for { + if e.acceptedChan == nil { + e.acceptMu.Unlock() + n.Close() + return + } + select { + case e.acceptedChan <- n: + e.acceptMu.Unlock() + e.waiterQueue.Notify(waiter.EventIn) + return + default: + e.acceptCond.Wait() + } } } @@ -420,11 +428,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } func (e *endpoint) incSynRcvdCount() bool { - if e.synRcvdCount >= cap(e.acceptedChan) { - return false + e.acceptMu.Lock() + canInc := e.synRcvdCount < cap(e.acceptedChan) + e.acceptMu.Unlock() + if canInc { + e.synRcvdCount++ } - e.synRcvdCount++ - return true + return canInc } func (e *endpoint) decSynRcvdCount() { @@ -432,10 +442,10 @@ func (e *endpoint) decSynRcvdCount() { } func (e *endpoint) acceptQueueIsFull() bool { - if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { - return true - } - return false + e.acceptMu.Lock() + full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + e.acceptMu.Unlock() + return full } // handleListenSegment is called when a listening endpoint receives a segment diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b6e571361..1ebee0cfe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -291,6 +291,7 @@ func (*EndpointInfo) IsEndpointInfo() {} // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // +// e.acceptMu -> protects acceptedChan. // e.rcvListMu -> Protects the rcvList and associated fields. // e.sndBufMu -> Protects the sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -533,6 +534,23 @@ type endpoint struct { // to the acceptedChan below terminate before we close acceptedChan. pendingAccepted sync.WaitGroup `state:"nosave"` + // acceptMu protects acceptedChan. + acceptMu sync.Mutex `state:"nosave"` + + // acceptCond is a condition variable that can be used to block on when + // acceptedChan is full and an endpoint is ready to be delivered. + // + // This condition variable is required because just blocking on sending + // to acceptedChan does not work in cases where endpoint.Listen is + // called twice with different backlog values. In such cases the channel + // is closed and a new one created. Any pending goroutines blocking on + // the write to the channel will panic. + // + // We use this condition variable to block/unblock goroutines which + // tried to deliver an endpoint but couldn't because accept backlog was + // full ( See: endpoint.deliverAccepted ). + acceptCond *sync.Cond `state:"nosave"` + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -814,6 +832,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.segmentQueue.setLimit(MaxUnprocessedSegments) e.tsOffset = timeStampOffset() + e.acceptCond = sync.NewCond(&e.acceptMu) return e } @@ -834,9 +853,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { case StateListen: // Check if there's anything in the accepted channel. if (mask & waiter.EventIn) != 0 { + e.acceptMu.Lock() if len(e.acceptedChan) > 0 { result |= waiter.EventIn } + e.acceptMu.Unlock() } } if e.EndpointState().connected() { @@ -981,29 +1002,19 @@ func (e *endpoint) closeNoShutdownLocked() { // closePendingAcceptableConnections closes all connections that have completed // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { - done := make(chan struct{}) - // Spin a goroutine up as ranging on e.acceptedChan will just block when - // there are no more connections in the channel. Using a non-blocking - // select does not work as it can potentially select the default case - // even when there are pending writes but that are not yet written to - // the channel. - go func() { - defer close(done) - for n := range e.acceptedChan { - n.notifyProtocolGoroutine(notifyReset) - // close all connections that have completed but - // not accepted by the application. - n.Close() - } - }() - // pendingAccepted(see endpoint.deliverAccepted) tracks the number of - // endpoints which have completed handshake but are not yet written to - // the e.acceptedChan. We wait here till the goroutine above can drain - // all such connections from e.acceptedChan. - e.pendingAccepted.Wait() + e.acceptMu.Lock() + if e.acceptedChan == nil { + e.acceptMu.Unlock() + return + } + close(e.acceptedChan) - <-done e.acceptedChan = nil + e.acceptCond.Broadcast() + e.acceptMu.Unlock() + + // Wait for all pending endpoints to close. + e.pendingAccepted.Wait() } // cleanupLocked frees all resources associated with the endpoint. It is called @@ -1012,9 +1023,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. - if e.acceptedChan != nil { - e.closePendingAcceptableConnectionsLocked() - } + e.closePendingAcceptableConnectionsLocked() e.workerCleanup = false @@ -2204,6 +2213,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { if e.EndpointState() == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. + e.acceptMu.Lock() + defer e.acceptMu.Unlock() if len(e.acceptedChan) > backlog { return tcpip.ErrInvalidEndpointState } @@ -2216,6 +2227,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { for ep := range origChan { e.acceptedChan <- ep } + + // Notify any blocked goroutines that they can attempt to + // deliver endpoints again. + e.acceptCond.Broadcast() + return nil } @@ -2245,9 +2261,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) // endpoints. + e.acceptMu.Lock() if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } + e.acceptMu.Unlock() + e.workerRunning = true go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) @@ -2276,9 +2295,12 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } // Get the new accepted endpoint. + e.acceptMu.Lock() + defer e.acceptMu.Unlock() var n *endpoint select { case n = <-e.acceptedChan: + e.acceptCond.Signal() default: return nil, nil, tcpip.ErrWouldBlock } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 9175de441..c3c692555 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -173,6 +173,9 @@ func (e *endpoint) afterLoad() { // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. e.state = StateInitial + // Condition variables and mutexs are not S/R'ed so reinitialize + // acceptCond with e.acceptMu. + e.acceptCond = sync.NewCond(&e.acceptMu) stack.StackFromEnv.RegisterRestoredEndpoint(e) } -- cgit v1.2.3 From 92b9069b67b927cef25a1490ebd142ad6d65690d Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Fri, 20 Mar 2020 12:00:21 -0700 Subject: Support owner matching for iptables. This feature will match UID and GID of the packet creator, for locally generated packets. This match is only valid in the OUTPUT and POSTROUTING chains. Forwarded packets do not have any socket associated with them. Packets from kernel threads do have a socket, but usually no owner. --- pkg/abi/linux/netfilter.go | 41 ++++++++ pkg/abi/linux/netfilter_test.go | 1 + pkg/sentry/kernel/task.go | 12 +++ pkg/sentry/socket/netfilter/BUILD | 1 + pkg/sentry/socket/netfilter/netfilter.go | 7 +- pkg/sentry/socket/netfilter/owner_matcher.go | 128 ++++++++++++++++++++++++ pkg/sentry/socket/netstack/provider.go | 6 ++ pkg/tcpip/network/ipv4/ipv4.go | 15 +++ pkg/tcpip/stack/packet_buffer.go | 9 +- pkg/tcpip/stack/transport_test.go | 2 + pkg/tcpip/tcpip.go | 12 +++ pkg/tcpip/transport/icmp/endpoint.go | 12 ++- pkg/tcpip/transport/packet/endpoint.go | 2 + pkg/tcpip/transport/raw/endpoint.go | 9 ++ pkg/tcpip/transport/tcp/accept.go | 5 +- pkg/tcpip/transport/tcp/connect.go | 10 +- pkg/tcpip/transport/tcp/endpoint.go | 7 ++ pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 12 ++- test/iptables/filter_output.go | 143 +++++++++++++++++++++++++++ test/iptables/iptables_test.go | 30 ++++++ 22 files changed, 451 insertions(+), 17 deletions(-) create mode 100644 pkg/sentry/socket/netfilter/owner_matcher.go (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 80dc09aa9..a8d4f9d69 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -509,3 +509,44 @@ const ( // Enable all flags. XT_UDP_INV_MASK = 0x03 ) + +// IPTOwnerInfo holds data for matching packets with owner. It corresponds +// to struct ipt_owner_info in libxt_owner.c of iptables binary. +type IPTOwnerInfo struct { + // UID is user id which created the packet. + UID uint32 + + // GID is group id which created the packet. + GID uint32 + + // PID is process id of the process which created the packet. + PID uint32 + + // SID is session id which created the packet. + SID uint32 + + // Comm is the command name which created the packet. + Comm [16]byte + + // Match is used to match UID/GID of the socket. See the + // XT_OWNER_* flags below. + Match uint8 + + // Invert flips the meaning of Match field. + Invert uint8 +} + +// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo. +const SizeOfIPTOwnerInfo = 34 + +// Flags in IPTOwnerInfo.Match. Corresponding constants are in +// include/uapi/linux/netfilter/xt_owner.h. +const ( + // Match the UID of the packet. + XT_OWNER_UID = 1 << 0 + // Match the GID of the packet. + XT_OWNER_GID = 1 << 1 + // Match if the socket exists for the packet. Forwarded + // packets do not have an associated socket. + XT_OWNER_SOCKET = 1 << 2 +) diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index 21e237f92..565dd550e 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -29,6 +29,7 @@ func TestSizes(t *testing.T) { {IPTGetEntries{}, SizeOfIPTGetEntries}, {IPTGetinfo{}, SizeOfIPTGetinfo}, {IPTIP{}, SizeOfIPTIP}, + {IPTOwnerInfo{}, SizeOfIPTOwnerInfo}, {IPTReplace{}, SizeOfIPTReplace}, {XTCounters{}, SizeOfXTCounters}, {XTEntryMatch{}, SizeOfXTEntryMatch}, diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 8452ddf5b..d6546735e 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -863,3 +863,15 @@ func (t *Task) SetOOMScoreAdj(adj int32) error { atomic.StoreInt32(&t.tg.oomScoreAdj, adj) return nil } + +// UID returns t's uid. +// TODO(gvisor.dev/issue/170): This method is not namespaced yet. +func (t *Task) UID() uint32 { + return uint32(t.Credentials().EffectiveKUID) +} + +// GID returns t's gid. +// TODO(gvisor.dev/issue/170): This method is not namespaced yet. +func (t *Task) GID() uint32 { + return uint32(t.Credentials().EffectiveKGID) +} diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index e801abeb8..721094bbf 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "extensions.go", "netfilter.go", + "owner_matcher.go", "targets.go", "tcp_matcher.go", "udp_matcher.go", diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 55bcc3ace..878f81fd5 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -517,11 +517,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT chain and redirect for - // PREROUTING chain right now, make sure all other chains point to - // ACCEPT rules. + // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, + // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook != stack.Input && hook != stack.Prerouting { + if hook == stack.Forward || hook == stack.Postrouting { if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go new file mode 100644 index 000000000..5949a7c29 --- /dev/null +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameOwner = "owner" + +func init() { + registerMatchMaker(ownerMarshaler{}) +} + +// ownerMarshaler implements matchMaker for owner matching. +type ownerMarshaler struct{} + +// name implements matchMaker.name. +func (ownerMarshaler) name() string { + return matcherNameOwner +} + +// marshal implements matchMaker.marshal. +func (ownerMarshaler) marshal(mr stack.Matcher) []byte { + matcher := mr.(*OwnerMatcher) + iptOwnerInfo := linux.IPTOwnerInfo{ + UID: matcher.uid, + GID: matcher.gid, + } + + // Support for UID match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if matcher.matchUID { + iptOwnerInfo.Match = linux.XT_OWNER_UID + } else if matcher.matchGID { + panic("GID match is not supported.") + } else { + panic("UID match is not set.") + } + + buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) + return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo)) +} + +// unmarshal implements matchMaker.unmarshal. +func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + if len(buf) < linux.SizeOfIPTOwnerInfo { + return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may + // exceed what's strictly necessary to hold matchData. + var matchData linux.IPTOwnerInfo + binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) + + if matchData.Invert != 0 { + return nil, fmt.Errorf("invert flag is not supported for owner match") + } + + // Support for UID match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if matchData.Match&linux.XT_OWNER_UID != linux.XT_OWNER_UID { + return nil, fmt.Errorf("owner match is only supported for uid") + } + + // Check Flags. + var owner OwnerMatcher + owner.uid = matchData.UID + owner.gid = matchData.GID + owner.matchUID = true + + return &owner, nil +} + +type OwnerMatcher struct { + uid uint32 + gid uint32 + matchUID bool + matchGID bool + invert uint8 +} + +// Name implements Matcher.Name. +func (*OwnerMatcher) Name() string { + return matcherNameOwner +} + +// Match implements Matcher.Match. +func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { + // Support only for OUTPUT chain. + // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. + if hook != stack.Output { + return false, true + } + + // If the packet owner is not set, drop the packet. + // Support for uid match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if pkt.Owner == nil || !om.matchUID { + return false, true + } + + // TODO(gvisor.dev/issue/170): Need to add tests to verify + // drop rule when packet UID does not match owner matcher UID. + if pkt.Owner.UID() != om.uid { + return false, false + } + + return true, false +} diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 5f181f017..eb090e79b 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -126,6 +126,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) } else { ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) + + // Assign task to PacketOwner interface to get the UID and GID for + // iptables owner matching. + if e == nil { + ep.SetOwner(t) + } } if e != nil { return nil, syserr.TranslateNetstackError(e) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b3ee6000e..a7d9a8b25 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -244,6 +244,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt); !ok { + // iptables is telling us to drop the packet. + return nil + } + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. @@ -280,7 +288,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac return len(pkts), nil } + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() for i := range pkts { + if ok := ipt.Check(stack.Output, pkts[i]); !ok { + // iptables is telling us to drop the packet. + continue + } ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) pkts[i].NetworkHeader = buffer.View(ip) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9505a4e92..9367de180 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -13,7 +13,10 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip/buffer" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) // A PacketBuffer contains all the data of a network packet. // @@ -59,6 +62,10 @@ type PacketBuffer struct { // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 + + // Owner is implemented by task to get the uid and gid. + // Only set for locally generated packets. + Owner tcpip.PacketOwner } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8ca9ac3cf..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -56,6 +56,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } +func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} + func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3dc5d87d6..2ef3271f1 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -336,6 +336,15 @@ type ControlMessages struct { PacketInfo IPPacketInfo } +// PacketOwner is used to get UID and GID of the packet. +type PacketOwner interface { + // UID returns UID of the packet. + UID() uint32 + + // GID returns GID of the packet. + GID() uint32 +} + // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) // that exposes functionality like read, write, connect, etc. to users of the // networking stack. @@ -470,6 +479,9 @@ type Endpoint interface { // Stats returns a reference to the endpoint stats. Stats() EndpointStats + + // SetOwner sets the task owner to the endpoint owner. + SetOwner(owner PacketOwner) } // EndpointInfo is the interface implemented by each endpoint info struct. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 613b12ead..b007302fb 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -73,6 +73,9 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -133,6 +136,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil @@ -321,7 +328,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl) + err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) case header.IPv6ProtocolNumber: err = send6(route, e.ID.LocalPort, v, e.ttl) @@ -415,7 +422,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } @@ -444,6 +451,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), + Owner: owner, }) } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index df49d0995..23158173d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -392,3 +392,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo { func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } + +func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 536dafd1e..337bc1c71 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -80,6 +80,9 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // NewEndpoint returns a raw endpoint for the given protocols. @@ -159,6 +162,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil @@ -348,10 +355,12 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } break } + hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), + Owner: e.owner, }); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 375ca21f6..7a9dea4ac 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -276,7 +276,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // and then performs the TCP 3-way handshake. // // The new endpoint is returned with e.mu held. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -284,6 +284,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if err != nil { return nil, err } + ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) @@ -414,7 +415,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header }() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 1d245c2c6..3239a5911 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -745,7 +745,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { tf.txHash = e.txHash - if err := sendTCP(r, tf, data, gso); err != nil { + if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() return err } @@ -787,7 +787,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -816,6 +816,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkts[i].DataSize = packetSize pkts[i].Data = data pkts[i].Hash = tf.txHash + pkts[i].Owner = owner buildTCPHdr(r, tf, &pkts[i], gso) off += packetSize tf.seq = tf.seq.Add(seqnum.Size(packetSize)) @@ -833,14 +834,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { - return sendTCPBatch(r, tf, data, gso) + return sendTCPBatch(r, tf, data, gso, owner) } pkt := stack.PacketBuffer{ @@ -849,6 +850,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac DataSize: data.Size(), Data: data, Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1ebee0cfe..9b123e968 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -603,6 +603,9 @@ type endpoint struct { // txHash is the transport layer hash to be set on outbound packets // emitted by this endpoint. txHash uint32 + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -1132,6 +1135,10 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index a094471b8..808410c92 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,7 +157,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }, queue) + }, queue, nil) if err != nil { return nil, err } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 1377107ca..dce9a1652 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -199,7 +199,7 @@ func replyWithReset(s *segment) { seq: seq, ack: ack, rcvWnd: 0, - }, buffer.VectorisedView{}, nil /* gso */) + }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) } // SetOption implements stack.TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index a3372ac58..120d3baa3 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -143,6 +143,9 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // +stateify savable @@ -484,7 +487,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -886,7 +889,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -916,6 +919,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Header: hdr, Data: data, TransportHeader: buffer.View(udp), + Owner: owner, }); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err @@ -1356,3 +1360,7 @@ func (*endpoint) Wait() {} func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } + +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index 4582d514c..f6d974b85 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -24,6 +24,11 @@ func init() { RegisterTestCase(FilterOutputDropTCPSrcPort{}) RegisterTestCase(FilterOutputDestination{}) RegisterTestCase(FilterOutputInvertDestination{}) + RegisterTestCase(FilterOutputAcceptTCPOwner{}) + RegisterTestCase(FilterOutputDropTCPOwner{}) + RegisterTestCase(FilterOutputAcceptUDPOwner{}) + RegisterTestCase(FilterOutputDropUDPOwner{}) + RegisterTestCase(FilterOutputOwnerFail{}) } // FilterOutputDropTCPDestPort tests that connections are not accepted on @@ -90,6 +95,144 @@ func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { return nil } +// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. +type FilterOutputAcceptTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptTCPOwner) Name() string { + return "FilterOutputAcceptTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection on port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection destined to port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. +type FilterOutputDropTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropTCPOwner) Name() string { + return "FilterOutputDropTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. +type FilterOutputAcceptUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptUDPOwner) Name() string { + return "FilterOutputAcceptUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Send UDP packets on acceptPort. + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. +type FilterOutputDropUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropUDPOwner) Name() string { + return "FilterOutputDropUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Send UDP packets on dropPort. + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received") + } + + return nil +} + +// FilterOutputOwnerFail tests that without uid/gid option, owner rule +// will fail. +type FilterOutputOwnerFail struct{} + +// Name implements TestCase.Name. +func (FilterOutputOwnerFail) Name() string { + return "FilterOutputOwnerFail" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { + return fmt.Errorf("Invalid argument") + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputOwnerFail) LocalAction(ip net.IP) error { + // no-op. + return nil +} + // FilterOutputDestination tests that we can selectively allow packets to // certain destinations. type FilterOutputDestination struct{} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 7f1f70606..493d69052 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -274,6 +274,36 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) { } } +func TestFilterOutputAcceptTCPOwner(t *testing.T) { + if err := singleTest(FilterOutputAcceptTCPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputDropTCPOwner(t *testing.T) { + if err := singleTest(FilterOutputDropTCPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputAcceptUDPOwner(t *testing.T) { + if err := singleTest(FilterOutputAcceptUDPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputDropUDPOwner(t *testing.T) { + if err := singleTest(FilterOutputDropUDPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputOwnerFail(t *testing.T) { + if err := singleTest(FilterOutputOwnerFail{}); err != nil { + t.Fatal(err) + } +} + func TestJumpSerialize(t *testing.T) { if err := singleTest(FilterInputSerializeJump{}); err != nil { t.Fatal(err) -- cgit v1.2.3 From 9c918340e4e6126cca1dfedbf28fec8c8f836e1a Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Wed, 15 Apr 2020 01:10:38 -0700 Subject: Reset pending connections on listener close Attempt to redeliver TCP segments that are enqueued into a closing TCP endpoint. This was being done for Established endpoints but not for those that are listening or performing connection handshake. Fixes #2417 PiperOrigin-RevId: 306598155 --- pkg/tcpip/transport/tcp/accept.go | 7 +++- pkg/tcpip/transport/tcp/connect.go | 30 ++++++++------ pkg/tcpip/transport/tcp/endpoint.go | 30 +++++++------- pkg/tcpip/transport/tcp/tcp_test.go | 37 +++++++++++++++++ test/packetimpact/tests/BUILD | 2 - test/syscalls/linux/socket_inet_loopback.cc | 62 +++++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 30 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7a9dea4ac..e07b436c4 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -330,6 +330,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if l.listenEP != nil { l.removePendingEndpoint(ep) } + + ep.drainClosingSegmentQueue() + return nil, err } ep.isConnectNotified = true @@ -378,7 +381,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { for { if e.acceptedChan == nil { e.acceptMu.Unlock() - n.Close() + n.notifyProtocolGoroutine(notifyReset) return } select { @@ -656,6 +659,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { } e.mu.Unlock() + e.drainClosingSegmentQueue() + // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 2ca3fb809..994ac52a3 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1062,6 +1062,20 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { } } +// Drain segment queue from the endpoint and try to re-match the segment to a +// different endpoint. This is used when the current endpoint is transitioned to +// StateClose and has been unregistered from the transport demuxer. +func (e *endpoint) drainClosingSegmentQueue() { + for { + s := e.segmentQueue.dequeue() + if s == nil { + break + } + + e.tryDeliverSegmentFromClosedEndpoint(s) + } +} + func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states @@ -1315,6 +1329,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.mu.Unlock() + + e.drainClosingSegmentQueue() + // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1565,19 +1582,6 @@ loop: // Lock released below. epilogue() - // epilogue removes the endpoint from the transport-demuxer and - // unlocks e.mu. Now that no new segments can get enqueued to this - // endpoint, try to re-match the segment to a different endpoint - // as the current endpoint is closed. - for { - s := e.segmentQueue.dequeue() - if s == nil { - break - } - - e.tryDeliverSegmentFromClosedEndpoint(s) - } - // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue if reuseTW != nil { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a8d443f73..7ed78d57f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -980,25 +980,22 @@ func (e *endpoint) closeNoShutdownLocked() { // Mark endpoint as closed. e.closed = true + + switch e.EndpointState() { + case StateClose, StateError: + return + } + // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. - switch e.EndpointState() { - // Sockets in StateSynRecv state(passive connections) are closed when - // the handshake fails or if the listening socket is closed while - // handshake was in progress. In such cases the handshake goroutine - // is already gone by the time Close is called and we need to cleanup - // here. - case StateInitial, StateBound, StateSynRecv: - e.cleanupLocked() - e.setEndpointState(StateClose) - case StateError, StateClose: - // do nothing. - default: + if e.workerRunning { e.workerCleanup = true tcpip.AddDanglingEndpoint(e) // Worker will remove the dangling endpoint when the endpoint // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) + } else { + e.transitionToStateCloseLocked() } } @@ -1010,13 +1007,18 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Unlock() return } - close(e.acceptedChan) + ch := e.acceptedChan e.acceptedChan = nil e.acceptCond.Broadcast() e.acceptMu.Unlock() - // Wait for all pending endpoints to close. + // Reset all connections that are waiting to be accepted. + for n := range ch { + n.notifyProtocolGoroutine(notifyReset) + } + // Wait for reset of all endpoints that are still waiting to be delivered to + // the now closed acceptedChan. e.pendingAccepted.Wait() } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 41caa9ed4..a9f121c17 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1068,6 +1068,43 @@ func TestListenShutdown(t *testing.T) { c.CheckNoPacket("Packet received when listening socket was shutdown") } +// TestListenCloseWhileConnect tests for the listening endpoint to +// drain the accept-queue when closed. This should reset all of the +// pending connections that are waiting to be accepted. +func TestListenCloseWhileConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(1 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventIn) + defer c.WQ.EventUnregister(&waitEntry) + + executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + // Wait for the new endpoint created because of handshake to be delivered + // to the listening endpoint's accept queue. + <-notifyCh + + // Close the listening endpoint. + c.EP.Close() + + // Expect the listening endpoint to reset the connection. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 308590162..1274d9f60 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -43,8 +43,6 @@ packetimpact_go_test( packetimpact_go_test( name = "tcp_noaccept_close_rst", srcs = ["tcp_noaccept_close_rst_test.go"], - # TODO(b/153380909): Fix netstack then remove the line below. - netstack = False, deps = [ "//pkg/tcpip/header", "//test/packetimpact/testbench", diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 71bd7c14d..cd84e633a 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -365,6 +365,68 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } } +TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kClients = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector clients; + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + clients.push_back(std::move(client)); + } + } + // Close the listening socket. + listen_fd.reset(); + + for (auto& client : clients) { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = client.get(), + .events = POLLIN, + }; + // When the listening socket is closed, then we expect the remote to reset + // the connection. + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR); + char c; + // Subsequent read can fail with: + // ECONNRESET: If the client connection was established and was reset by the + // remote. ECONNREFUSED: If the client connection failed to be established. + ASSERT_THAT(read(client.get(), &c, sizeof(c)), + AnyOf(SyscallFailsWithErrno(ECONNRESET), + SyscallFailsWithErrno(ECONNREFUSED))); + } +} + TEST_P(SocketInetLoopbackTest, TCPbacklog) { auto const& param = GetParam(); -- cgit v1.2.3 From 0eda0104a5a7c95a36dd288199ec1e90be9d8be9 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 16 Apr 2020 16:48:14 -0700 Subject: Fix data race in tcp_test. This change makes SynRcvdCountThreshold and the global synRcvdCount into a stack configurable value. This is required because in cases like mod_proxy which create multiple Stack instances the count will be a global value that impacts all Stack instances. Further the tests relied on modifying the global threshold to simulate tests where we want to verify SYN cookie based behaviour. This lead to data races due to the global being modified/read without locks or atomics. PiperOrigin-RevId: 306947723 --- pkg/tcpip/tcpip.go | 5 ++ pkg/tcpip/transport/tcp/accept.go | 106 ++++++++++---------------- pkg/tcpip/transport/tcp/dual_stack_test.go | 9 +-- pkg/tcpip/transport/tcp/protocol.go | 77 +++++++++++++++++++ pkg/tcpip/transport/tcp/tcp_sack_test.go | 39 +++++----- pkg/tcpip/transport/tcp/tcp_test.go | 32 ++++---- pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 30 ++++---- 7 files changed, 176 insertions(+), 122 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 109121dbc..1ca4088c9 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -685,6 +685,11 @@ type TCPDeferAcceptOption time.Duration // default MinRTO used by the Stack. type TCPMinRTOOption time.Duration +// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify +// the number of endpoints that can be in SYN-RCVD state before the stack +// switches to using SYN cookies. +type TCPSynRcvdCountThresholdOption uint64 + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e07b436c4..b61c2a8c3 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -17,6 +17,7 @@ package tcp import ( "crypto/sha1" "encoding/binary" + "fmt" "hash" "io" "time" @@ -49,17 +50,14 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. maxTSDiff = 2 -) -var ( - // SynRcvdCountThreshold is the global maximum number of connections - // that are allowed to be in SYN-RCVD state before TCP starts using SYN - // cookies to accept connections. - // - // It is an exported variable only for testing, and should not otherwise - // be used by importers of this package. + // SynRcvdCountThreshold is the default global maximum number of + // connections that are allowed to be in SYN-RCVD state before TCP + // starts using SYN cookies to accept connections. SynRcvdCountThreshold uint64 = 1000 +) +var ( // mssTable is a slice containing the possible MSS values that we // encode in the SYN cookie with two bits. mssTable = []uint16{536, 1300, 1440, 1460} @@ -74,29 +72,42 @@ func encodeMSS(mss uint16) uint32 { return 0 } -// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is -// protected by a mutex so that we can increment only when it's guaranteed not -// to go above a threshold. -var synRcvdCount struct { - sync.Mutex - value uint64 - pending sync.WaitGroup -} - // listenContext is used by a listening endpoint to store state used while // listening for connections. This struct is allocated by the listen goroutine // and must not be accessed or have its methods called concurrently as they // may mutate the stored objects. type listenContext struct { - stack *stack.Stack - rcvWnd seqnum.Size - nonce [2][sha1.BlockSize]byte + stack *stack.Stack + + // synRcvdCount is a reference to the stack level synRcvdCount. + synRcvdCount *synRcvdCounter + + // rcvWnd is the receive window that is sent by this listening context + // in the initial SYN-ACK. + rcvWnd seqnum.Size + + // nonce are random bytes that are initialized once when the context + // is created and used to seed the hash function when generating + // the SYN cookie. + nonce [2][sha1.BlockSize]byte + + // listenEP is a reference to the listening endpoint associated with + // this context. Can be nil if the context is created by the forwarder. listenEP *endpoint + // hasherMu protects hasher. hasherMu sync.Mutex - hasher hash.Hash - v6only bool + // hasher is the hash function used to generate a SYN cookie. + hasher hash.Hash + + // v6Only is true if listenEP is a dual stack socket and has the + // IPV6_V6ONLY option set. + v6only bool + + // netProto indicates the network protocol(IPv4/v6) for the listening + // endpoint. netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed // by the listening endpoint's worker goroutine. // @@ -115,44 +126,6 @@ func timeStamp() uint32 { return uint32(time.Now().Unix()>>6) & tsMask } -// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD -// state. It succeeds if the increment doesn't make the count go beyond the -// threshold, and fails otherwise. -func incSynRcvdCount() bool { - synRcvdCount.Lock() - - if synRcvdCount.value >= SynRcvdCountThreshold { - synRcvdCount.Unlock() - return false - } - - synRcvdCount.pending.Add(1) - synRcvdCount.value++ - - synRcvdCount.Unlock() - return true -} - -// decSynRcvdCount atomically decrements the global number of endpoints in -// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount -// succeeded. -func decSynRcvdCount() { - synRcvdCount.Lock() - - synRcvdCount.value-- - synRcvdCount.pending.Done() - synRcvdCount.Unlock() -} - -// synCookiesInUse() returns true if the synRcvdCount is greater than -// SynRcvdCountThreshold. -func synCookiesInUse() bool { - synRcvdCount.Lock() - v := synRcvdCount.value - synRcvdCount.Unlock() - return v >= SynRcvdCountThreshold -} - // newListenContext creates a new listen context. func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ @@ -164,6 +137,11 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } + p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) + if !ok { + panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) + } + l.synRcvdCount = p.SynRcvdCounter() rand.Read(l.nonce[0][:]) rand.Read(l.nonce[1][:]) @@ -410,7 +388,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer decSynRcvdCount() + defer ctx.synRcvdCount.dec() defer func() { e.mu.Lock() e.decSynRcvdCount() @@ -477,7 +455,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { switch { case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) - if incSynRcvdCount() { + if ctx.synRcvdCount.inc() { // Only handle the syn if the following conditions hold // - accept queue is not full. // - number of connections in synRcvd state is less than the @@ -487,7 +465,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. return } - decSynRcvdCount() + ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -540,7 +518,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } - if !synCookiesInUse() { + if !ctx.synRcvdCount.synCookiesInUse() { // When not using SYN cookies, as per RFC 793, section 3.9, page 64: // Any acknowledgment is bad if it arrives on a connection still in // the LISTEN state. An acceptable reset segment should be formed diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 4f361b226..804e95aea 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -568,11 +568,10 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + } + const n = uint16(32) // Start listening. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 91f25c132..effbf203f 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -94,6 +94,63 @@ const ( ccCubic = "cubic" ) +// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The +// value is protected by a mutex so that we can increment only when it's +// guaranteed not to go above a threshold. +type synRcvdCounter struct { + sync.Mutex + value uint64 + pending sync.WaitGroup + threshold uint64 +} + +// inc tries to increment the global number of endpoints in SYN-RCVD state. It +// succeeds if the increment doesn't make the count go beyond the threshold, and +// fails otherwise. +func (s *synRcvdCounter) inc() bool { + s.Lock() + defer s.Unlock() + if s.value >= s.threshold { + return false + } + + s.pending.Add(1) + s.value++ + + return true +} + +// dec atomically decrements the global number of endpoints in SYN-RCVD +// state. It must only be called if a previous call to inc succeeded. +func (s *synRcvdCounter) dec() { + s.Lock() + defer s.Unlock() + s.value-- + s.pending.Done() +} + +// synCookiesInUse returns true if the synRcvdCount is greater than +// SynRcvdCountThreshold. +func (s *synRcvdCounter) synCookiesInUse() bool { + s.Lock() + defer s.Unlock() + return s.value >= s.threshold +} + +// SetThreshold sets synRcvdCounter.Threshold to ths new threshold. +func (s *synRcvdCounter) SetThreshold(threshold uint64) { + s.Lock() + defer s.Unlock() + s.threshold = threshold +} + +// Threshold returns the current value of synRcvdCounter.Threhsold. +func (s *synRcvdCounter) Threshold() uint64 { + s.Lock() + defer s.Unlock() + return s.threshold +} + type protocol struct { mu sync.RWMutex sackEnabled bool @@ -106,6 +163,7 @@ type protocol struct { tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration minRTO time.Duration + synRcvdCount synRcvdCounter dispatcher *dispatcher } @@ -282,6 +340,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPSynRcvdCountThresholdOption: + p.mu.Lock() + p.synRcvdCount.SetThreshold(uint64(v)) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -350,6 +414,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *tcpip.TCPSynRcvdCountThresholdOption: + p.mu.RLock() + *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + p.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -365,6 +435,12 @@ func (p *protocol) Wait() { p.dispatcher.wait() } +// SynRcvdCounter returns a reference to the synRcvdCount for this protocol +// instance. +func (p *protocol) SynRcvdCounter() *synRcvdCounter { + return &p.synRcvdCount +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ @@ -374,6 +450,7 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), minRTO: MinRTO, } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index c439d5281..1dd63dd61 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -149,21 +149,22 @@ func TestSackPermittedAccept(t *testing.T) { {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + for _, tc := range testCases { t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } for _, sackEnabled := range []bool{false, true} { t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + + if tc.cookieEnabled { + // Set the SynRcvd threshold to + // zero to force a syn cookie + // based accept to happen. + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } setStackSACKPermitted(t, c, sackEnabled) rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) @@ -222,21 +223,23 @@ func TestSackDisabledAccept(t *testing.T) { {true, -1, 0xffff}, // When cookie is used window scaling is disabled. {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). } - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + for _, tc := range testCases { t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - if tc.cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } else { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - } for _, sackEnabled := range []bool{false, true} { t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + + if tc.cookieEnabled { + // Set the SynRcvd threshold to + // zero to force a syn cookie + // based accept to happen. + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + setStackSACKPermitted(t, c, sackEnabled) rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a9f121c17..74fb6e064 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2706,26 +2706,24 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } // Create EP and start listening. wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake. @@ -2743,7 +2741,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -5143,25 +5141,23 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } func TestListenBacklogFullSynCookieInUse(t *testing.T) { - saved := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = saved - }() - tcp.SynRcvdCountThreshold = 1 - c := context.New(t, defaultMTU) defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + } + // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Bind to wildcard. if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -5169,7 +5165,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { listenBacklog := 1 portOffset := uint16(0) if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } executeHandshake(t, c, context.TestPort+portOffset, false) diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index a641e953d..8edbff964 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -127,16 +127,14 @@ func TestTimeStampDisabledConnect(t *testing.T) { } func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() + c := context.New(t, defaultMTU) + defer c.Cleanup() if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } } - c := context.New(t, defaultMTU) - defer c.Cleanup() t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) tsVal := rand.Uint32() @@ -148,7 +146,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Unexpected error from Write: %s", err) } // Check that data is received and that the timestamp option TSEcr field @@ -190,17 +188,15 @@ func TestTimeStampEnabledAccept(t *testing.T) { } func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - savedSynCountThreshold := tcp.SynRcvdCountThreshold - defer func() { - tcp.SynRcvdCountThreshold = savedSynCountThreshold - }() - if cookieEnabled { - tcp.SynRcvdCountThreshold = 0 - } - c := context.New(t, defaultMTU) defer c.Cleanup() + if cookieEnabled { + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) @@ -211,7 +207,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Unexpected error from Write: %s", err) } // Check that data is received and that the timestamp option is disabled -- cgit v1.2.3 From 3b05f576d73be644daa17203d9ed64481c45b4a8 Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Thu, 16 Apr 2020 17:57:06 -0700 Subject: Reset pending connections on listener shutdown. When the listening socket is read shutdown, we need to reset all pending and incoming connections. Ensure that the endpoint is not cleaned up from the demuxer and subsequent bind to same port does not go through. PiperOrigin-RevId: 306958038 --- pkg/tcpip/transport/tcp/accept.go | 20 +++--- pkg/tcpip/transport/tcp/connect.go | 7 ++- pkg/tcpip/transport/tcp/endpoint.go | 30 ++++++--- pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 8 +-- pkg/tcpip/transport/tcp/tcp_test.go | 16 ++--- test/syscalls/linux/socket_inet_loopback.cc | 94 +++++++++++++++++++++++++++-- 7 files changed, 138 insertions(+), 39 deletions(-) (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b61c2a8c3..5bb243e3b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -433,19 +432,16 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() + if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + // If the endpoint is shutdown, reply with reset. + // // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - e.sendTCP(&s.route, tcpFields{ - id: s.id, - ttl: e.ttl, - tos: e.sendTOS, - flags: header.TCPFlagRst, - seq: s.ackNumber, - ack: 0, - rcvWnd: 0, - }, buffer.VectorisedView{}, nil) + replyWithReset(s, e.sendTOS, e.ttl) return } @@ -534,7 +530,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s) + replyWithReset(s, e.sendTOS, e.ttl) return } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 994ac52a3..368865911 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1053,10 +1053,15 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) } if ep == nil { - replyWithReset(s) + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) s.decRef() return } + + if e == ep { + panic("current endpoint not removed from demuxer, enqueing segments to itself") + } + if ep.(*endpoint).enqueueSegment(s) { ep.(*endpoint).newSegmentWaker.Assert() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bffc59e9f..5d0ea9e93 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2101,7 +2101,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { switch { case e.EndpointState().connected(): // Close for read. - if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { + if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. e.rcvListMu.Lock() e.rcvClosed = true @@ -2110,7 +2110,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. - if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { + if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { e.resetConnectionLocked(tcpip.ErrConnectionAborted) // Wake up worker to terminate loop. e.notifyProtocolGoroutine(notifyTickleWorker) @@ -2119,7 +2119,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } // Close for write. - if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { e.sndBufMu.Lock() if e.sndClosed { // Already closed. @@ -2142,12 +2142,23 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { return nil case e.EndpointState() == StateListen: - // Tell protocolListenLoop to stop. - if flags&tcpip.ShutdownRead != 0 { - e.notifyProtocolGoroutine(notifyClose) + if e.shutdownFlags&tcpip.ShutdownRead != 0 { + // Reset all connections from the accept queue and keep the + // worker running so that it can continue handling incoming + // segments by replying with RST. + // + // By not removing this endpoint from the demuxer mapping, we + // ensure that any other bind to the same port fails, as on Linux. + // TODO(gvisor.dev/issue/2468): We need to enable applications to + // start listening on this endpoint again similar to Linux. + e.rcvListMu.Lock() + e.rcvClosed = true + e.rcvListMu.Unlock() + e.closePendingAcceptableConnectionsLocked() + // Notify waiters that the endpoint is shutdown. + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) } return nil - default: return tcpip.ErrNotConnected } @@ -2251,8 +2262,11 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. - if e.EndpointState() != StateListen { + if rcvClosed || e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 808410c92..704d01c64 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -130,7 +130,7 @@ func (r *ForwarderRequest) Complete(sendReset bool) { // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment) + replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) } // Release all resources. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index effbf203f..cfd9a4e8e 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -223,12 +223,12 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo return true } - replyWithReset(s) + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) return true } // replyWithReset replies to the given segment with a reset segment. -func replyWithReset(s *segment) { +func replyWithReset(s *segment, tos, ttl uint8) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) @@ -252,8 +252,8 @@ func replyWithReset(s *segment) { } sendTCP(&s.route, tcpFields{ id: s.id, - ttl: s.route.DefaultTTL(), - tos: stack.DefaultTOS, + ttl: ttl, + tos: tos, flags: flags, seq: seq, ack: ack, diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 74fb6e064..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1034,8 +1034,8 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.SeqNum(200))) } -// TestListenShutdown tests for the listening endpoint not processing -// any receive when it is on read shutdown. +// TestListenShutdown tests for the listening endpoint replying with RST +// on read shutdown. func TestListenShutdown(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -1046,7 +1046,7 @@ func TestListenShutdown(t *testing.T) { t.Fatal("Bind failed:", err) } - if err := c.EP.Listen(10 /* backlog */); err != nil { + if err := c.EP.Listen(1 /* backlog */); err != nil { t.Fatal("Listen failed:", err) } @@ -1054,9 +1054,6 @@ func TestListenShutdown(t *testing.T) { t.Fatal("Shutdown failed:", err) } - // Wait for the endpoint state to be propagated. - time.Sleep(10 * time.Millisecond) - c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -1065,7 +1062,12 @@ func TestListenShutdown(t *testing.T) { AckNum: 200, }) - c.CheckNoPacket("Packet received when listening socket was shutdown") + // Expect the listening endpoint to reset the connection. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) } // TestListenCloseWhileConnect tests for the listening endpoint to diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index cd84e633a..d3000dbc6 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -319,6 +319,75 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { tcpSimpleConnectTest(listener, connector, false); } +TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kFDs = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + // Shutdown the write of the listener, expect to not have any effect. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds()); + + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); + } + + // Shutdown the read of the listener, expect to fail subsequent + // server accepts, binds and client connects. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EINVAL)); + + // Check that shutdown did not release the port. + FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + bind(new_listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Check that subsequent connection attempts receive a RST. + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallFailsWithErrno(ECONNREFUSED)); + } +} + TEST_P(SocketInetLoopbackTest, TCPListenClose) { auto const& param = GetParam(); @@ -365,9 +434,8 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } } -TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { - auto const& param = GetParam(); - +void TestListenWhileConnect(const TestParam& param, + void (*stopListen)(FileDescriptor&)) { TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -404,8 +472,8 @@ TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { clients.push_back(std::move(client)); } } - // Close the listening socket. - listen_fd.reset(); + + stopListen(listen_fd); for (auto& client : clients) { const int kTimeout = 10000; @@ -420,13 +488,26 @@ TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { char c; // Subsequent read can fail with: // ECONNRESET: If the client connection was established and was reset by the - // remote. ECONNREFUSED: If the client connection failed to be established. + // remote. + // ECONNREFUSED: If the client connection failed to be established. ASSERT_THAT(read(client.get(), &c, sizeof(c)), AnyOf(SyscallFailsWithErrno(ECONNRESET), SyscallFailsWithErrno(ECONNREFUSED))); } } +TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(close(f.release()), SyscallSucceeds()); + }); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); + }); +} + TEST_P(SocketInetLoopbackTest, TCPbacklog) { auto const& param = GetParam(); @@ -1134,6 +1215,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { if (connects_received >= kConnectAttempts) { // Another thread have shutdown our read side causing the // accept to fail. + ASSERT_EQ(errno, EINVAL); break; } ASSERT_NO_ERRNO(fd); -- cgit v1.2.3 From db2a60be67f0e869a58eb12d253a0d7fe13ebfa3 Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Sun, 19 Apr 2020 22:14:53 -0700 Subject: Don't accept segments outside the receive window Fixed to match RFC 793 page 69. Fixes #1607 PiperOrigin-RevId: 307334892 --- pkg/tcpip/seqnum/seqnum.go | 5 -- pkg/tcpip/transport/tcp/BUILD | 10 +++ pkg/tcpip/transport/tcp/accept.go | 14 ++--- pkg/tcpip/transport/tcp/connect.go | 15 +---- pkg/tcpip/transport/tcp/endpoint.go | 13 ++++ pkg/tcpip/transport/tcp/rcv.go | 23 +++++-- pkg/tcpip/transport/tcp/rcv_test.go | 74 +++++++++++++++++++++++ pkg/tcpip/transport/tcpconntrack/BUILD | 1 + pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 13 +--- test/packetimpact/tests/BUILD | 2 - 10 files changed, 125 insertions(+), 45 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/rcv_test.go (limited to 'pkg/tcpip/transport/tcp/accept.go') diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go index b40a3c212..d3bea7de4 100644 --- a/pkg/tcpip/seqnum/seqnum.go +++ b/pkg/tcpip/seqnum/seqnum.go @@ -46,11 +46,6 @@ func (v Value) InWindow(first Value, size Size) bool { return v.InRange(first, first.Add(size)) } -// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y). -func Overlap(a Value, b Size, x Value, y Size) bool { - return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b)) -} - // Add calculates the sequence number following the [v, v+s) window. func (v Value) Add(s Size) Value { return v + Value(s) diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index edb7718a6..61426623c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -109,3 +109,13 @@ go_test( "//runsc/testutil", ], ) + +go_test( + name = "rcv_test", + size = "small", + srcs = ["rcv_test.go"], + deps = [ + ":tcp", + "//pkg/tcpip/seqnum", + ], +) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 5bb243e3b..e6a23c978 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -101,7 +101,7 @@ type listenContext struct { // v6Only is true if listenEP is a dual stack socket and has the // IPV6_V6ONLY option set. - v6only bool + v6Only bool // netProto indicates the network protocol(IPv4/v6) for the listening // endpoint. @@ -126,12 +126,12 @@ func timeStamp() uint32 { } // newListenContext creates a new listen context. -func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ stack: stk, rcvWnd: rcvWnd, hasher: sha1.New(), - v6only: v6only, + v6Only: v6Only, netProto: netProto, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), @@ -207,7 +207,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i netProto = s.route.NetProto } n := newEndpoint(l.stack, netProto, queue) - n.v6only = l.v6only + n.v6only = l.v6Only n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() @@ -293,7 +293,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head } // Perform the 3-way handshake. - h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) + h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.mu.Unlock() ep.Close() @@ -613,8 +613,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() - v6only := e.v6only - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) + v6Only := e.v6only + ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 368865911..76e27bf26 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -105,24 +105,11 @@ type handshake struct { } func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { - rcvWndScale := ep.rcvWndScaleForHandshake() - - // Round-down the rcvWnd to a multiple of wndScale. This ensures that the - // window offered in SYN won't be reduced due to the loss of precision if - // window scaling is enabled after the handshake. - rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale) - - // Ensure we can always accept at least 1 byte if the scale specified - // was too high for the provided rcvWnd. - if rcvWnd == 0 { - rcvWnd = 1 - } - h := handshake{ ep: ep, active: true, rcvWnd: rcvWnd, - rcvWndScale: int(rcvWndScale), + rcvWndScale: ep.rcvWndScaleForHandshake(), } h.resetState() return h diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7e8def82d..45f2aa78b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1062,6 +1062,19 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > routeWnd { rcvWnd = routeWnd } + rcvWndScale := e.rcvWndScaleForHandshake() + + // Round-down the rcvWnd to a multiple of wndScale. This ensures that the + // window offered in SYN won't be reduced due to the loss of precision if + // window scaling is enabled after the handshake. + rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale) + + // Ensure we can always accept at least 1 byte if the scale specified + // was too high for the provided rcvWnd. + if rcvWnd == 0 { + rcvWnd = 1 + } + return rcvWnd } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index caf8977b3..a4b73b588 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -70,13 +70,24 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale // acceptable checks if the segment sequence number range is acceptable // according to the table on page 26 of RFC 793. func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - rcvWnd := r.rcvNxt.Size(r.rcvAcc) - if rcvWnd == 0 { - return segLen == 0 && segSeq == r.rcvNxt - } + return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) +} - return segSeq.InWindow(r.rcvNxt, rcvWnd) || - seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen) +// Acceptable checks if a segment that starts at segSeq and has length segLen is +// "acceptable" for arriving in a receive window that starts at rcvNxt and ends +// before rcvAcc, according to the table on page 26 and 69 of RFC 793. +func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { + if rcvNxt == rcvAcc { + return segLen == 0 && segSeq == rcvNxt + } + if segLen == 0 { + // rcvWnd is incremented by 1 because that is Linux's behavior despite the + // RFC. + return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) + } + // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming + // the payload, so we'll accept any payload that overlaps the receieve window. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) } // getSendParams returns the parameters needed by the sender when building diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go new file mode 100644 index 000000000..dc02729ce --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rcv_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" +) + +func TestAcceptable(t *testing.T) { + for _, tt := range []struct { + segSeq seqnum.Value + segLen seqnum.Size + rcvNxt, rcvAcc seqnum.Value + want bool + }{ + // The segment is smaller than the window. + {105, 2, 100, 104, false}, + {105, 2, 101, 105, false}, + {105, 2, 102, 106, true}, + {105, 2, 103, 107, true}, + {105, 2, 104, 108, true}, + {105, 2, 105, 109, true}, + {105, 2, 106, 110, true}, + {105, 2, 107, 111, false}, + + // The segment is larger than the window. + {105, 4, 103, 105, false}, + {105, 4, 104, 106, true}, + {105, 4, 105, 107, true}, + {105, 4, 106, 108, true}, + {105, 4, 107, 109, true}, + {105, 4, 108, 110, true}, + {105, 4, 109, 111, false}, + {105, 4, 110, 112, false}, + + // The segment has no width. + {105, 0, 100, 102, false}, + {105, 0, 101, 103, false}, + {105, 0, 102, 104, false}, + {105, 0, 103, 105, true}, + {105, 0, 104, 106, true}, + {105, 0, 105, 107, true}, + {105, 0, 106, 108, false}, + {105, 0, 107, 109, false}, + + // The receive window has no width. + {105, 2, 103, 103, false}, + {105, 2, 104, 104, false}, + {105, 2, 105, 105, false}, + {105, 2, 106, 106, false}, + {105, 2, 107, 107, false}, + {105, 2, 108, 108, false}, + {105, 2, 109, 109, false}, + } { + if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { + t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) + } + } +} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 3ad6994a7..2025ff757 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", + "//pkg/tcpip/transport/tcp", ], ) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 93712cd45..30d05200f 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -20,6 +20,7 @@ package tcpconntrack import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) // Result is returned when the state of a TCB is updated in response to an @@ -311,17 +312,7 @@ type stream struct { // the window is zero, if it's a packet with no payload and sequence number // equal to una. func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - wnd := s.una.Size(s.end) - if wnd == 0 { - return segLen == 0 && segSeq == s.una - } - - // Make sure [segSeq, seqSeq+segLen) is non-empty. - if segLen == 0 { - segLen = 1 - } - - return seqnum.Overlap(s.una, wnd, segSeq, segLen) + return tcp.Acceptable(segSeq, segLen, s.una, s.end) } // closed determines if the stream has already been closed. This happens when diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 1410512e6..47c722ccd 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -43,8 +43,6 @@ packetimpact_go_test( packetimpact_go_test( name = "tcp_outside_the_window", srcs = ["tcp_outside_the_window_test.go"], - # TODO(eyalsoha): Fix #1607 then remove the line below. - netstack = False, deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", -- cgit v1.2.3