From aee2c93366f451b9cc0a62430185749556fc3900 Mon Sep 17 00:00:00 2001 From: Jianfeng Tan Date: Thu, 29 Aug 2019 16:23:11 +0000 Subject: netstack: add counters for tcp CurrEstab and EstabResets Signed-off-by: Jianfeng Tan --- pkg/tcpip/transport/tcp/snd.go | 1 + 1 file changed, 1 insertion(+) (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8332a0179..d3f7c9125 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -674,6 +674,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } + s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. -- cgit v1.2.3 From b1d44be7ad893bd6bdfd164a54a7142f4462414b Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Fri, 6 Dec 2019 17:15:52 -0800 Subject: Add TCP stats for connection close and keep-alive timeouts. Fix bugs in updates to TCP CurrentEstablished stat. Fixes #1277 PiperOrigin-RevId: 284292459 --- pkg/sentry/socket/netstack/netstack.go | 2 ++ pkg/tcpip/tcpip.go | 8 ++++++ pkg/tcpip/transport/tcp/connect.go | 5 ++-- pkg/tcpip/transport/tcp/snd.go | 1 - pkg/tcpip/transport/tcp/tcp_test.go | 46 ++++++++++++++++++++++++++++++++++ 5 files changed, 58 insertions(+), 4 deletions(-) (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d92399efd..fe5a46aa3 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -151,6 +151,8 @@ var Metrics = tcpip.Stats{ PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 5746043cc..d5bb5b6ed 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -924,6 +924,14 @@ type TCPStats struct { // ESTABLISHED state or the CLOSE-WAIT state. EstablishedResets *StatCounter + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of keep-alive time out. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 2975a1c3c..3d059c302 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -924,6 +924,7 @@ func (e *endpoint) transitionToStateCloseLocked() { } e.cleanupLocked() e.state = StateClose + e.stack.Stats().TCP.EstablishedClosed.Increment() } // tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed @@ -1094,6 +1095,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -1179,8 +1181,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError e.HardError = err @@ -1389,7 +1389,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - e.stack.Stats().TCP.EstablishedResets.Increment() e.stack.Stats().TCP.CurrentEstablished.Decrement() e.transitionToStateCloseLocked() } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index d3f7c9125..8332a0179 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -674,7 +674,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } - s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 52c2fa7e3..bc5cfcf0e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) { if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) } + + // Call Connect again to retreive the handshake failure status + // and stats updates. + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) + } + + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestConnectIncrementActiveConnection(t *testing.T) { @@ -548,6 +562,14 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + // Check if the endpoint was moved to CLOSED and netstack a reset in // response to the ACK packet that we sent after last-ACK. checker.IPv4(t, c.GetPacket(), @@ -2694,6 +2716,13 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { @@ -4363,9 +4392,17 @@ func TestKeepalive(t *testing.T) { ), ) + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + } + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -5992,6 +6029,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) } + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -6120,6 +6159,13 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.SeqNum(uint32(ackHeaders.AckNum)), checker.AckNum(uint32(ackHeaders.SeqNum)), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestTCPCloseWithData(t *testing.T) { -- cgit v1.2.3 From 6fc9f0aefd89ce42ef2c38ea7853f9ba7c4bee04 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 11 Dec 2019 17:51:37 -0800 Subject: Add support for TCP_USER_TIMEOUT option. The implementation follows the linux behavior where specifying a TCP_USER_TIMEOUT will cause the resend timer to honor the user specified timeout rather than the default rto based timeout. Further it alters when connections are timedout due to keepalive failures. It does not alter the behavior of when keepalives are sent. This is as per the linux behavior. PiperOrigin-RevId: 285099795 --- pkg/sentry/socket/netstack/netstack.go | 23 ++++ pkg/tcpip/tcpip.go | 5 + pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 15 +++ pkg/tcpip/transport/tcp/connect.go | 19 ++- pkg/tcpip/transport/tcp/endpoint.go | 19 +++ pkg/tcpip/transport/tcp/protocol.go | 21 ++- pkg/tcpip/transport/tcp/rcv.go | 19 ++- pkg/tcpip/transport/tcp/rcv_state.go | 29 ++++ pkg/tcpip/transport/tcp/snd.go | 48 ++++++- pkg/tcpip/transport/tcp/snd_state.go | 10 ++ pkg/tcpip/transport/tcp/tcp_test.go | 194 ++++++++++++++++++++++++--- test/syscalls/linux/socket_inet_loopback.cc | 56 +++++++- test/syscalls/linux/socket_ip_tcp_generic.cc | 63 +++++++++ test/syscalls/linux/tcp_socket.cc | 25 ++++ 15 files changed, 509 insertions(+), 38 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/rcv_state.go (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index fe5a46aa3..8a6522eac 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1127,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Millisecond), nil + case linux.TCP_INFO: var v tcpip.TCPInfoOption if err := ep.GetSockOpt(&v); err != nil { @@ -1563,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d5bb5b6ed..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -576,6 +576,11 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user +// specified timeout for a given TCP connection. +// See: RFC5482 for details. +type TCPUserTimeoutOption time.Duration + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 455a1c098..3b353d56c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -28,6 +28,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 74df3edfb..5422ae80c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -242,6 +242,13 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() + // Now inherit any socket options that should be inherited from the + // listening endpoint. + // In case of Forwarder listenEP will be nil and hence this check. + if l.listenEP != nil { + l.listenEP.propagateInheritableOptions(n) + } + // Register new endpoint so that packets are routed to it. if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() @@ -350,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } +// propagateInheritableOptions propagates any options set on the listening +// endpoint to the newly created endpoint. +func (e *endpoint) propagateInheritableOptions(n *endpoint) { + e.mu.Lock() + n.userTimeout = e.userTimeout + e.mu.Unlock() +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3d059c302..4c34fc9d2 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -862,7 +862,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } e.state = StateError e.HardError = err - if err != tcpip.ErrConnectionReset { + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -1087,12 +1087,24 @@ func (e *endpoint) handleSegments() *tcpip.Error { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.mu.RLock() + userTimeout := e.userTimeout + e.mu.RUnlock() + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() @@ -1112,7 +1124,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -1120,6 +1131,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -1127,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -1239,6 +1252,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil @@ -1405,6 +1419,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if s == nil { break } + e.tryDeliverSegmentFromClosedEndpoint(s) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4861ab513..dd8b47cbe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -341,6 +341,7 @@ type endpoint struct { // TCP should never broadcast but Linux nevertheless supports enabling/ // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID @@ -474,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -1333,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1591,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 89b965c23..bc718064c 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -162,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo func replyWithReset(s *segment) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 5ee499c36..0a5534959 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -50,16 +50,20 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } @@ -360,6 +364,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { return true, nil } + // Store the time of the last ack. + r.lastRcvdAckTime = time.Now() + // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go new file mode 100644 index 000000000..2bf21a2e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8332a0179..8a947dc66 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -28,8 +28,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. - minRTO = 200 * time.Millisecond + // MinRTO is the minimum allowed value for the retransmit timeout. + MinRTO = 200 * time.Millisecond + + // MaxRTO is the maximum allowed value for the retransmit timeout. + MaxRTO = 120 * time.Second // InitialCwnd is the initial congestion window. InitialCwnd = 10 @@ -134,6 +137,10 @@ type sender struct { // rttMeasureTime is the time when the rttMeasureSeqNum was sent. rttMeasureTime time.Time `state:".(unixTime)"` + // firstRetransmittedSegXmitTime is the original transmit time of + // the first segment that was retransmitted due to RTO expiration. + firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + closed bool writeNext *segment writeList segmentList @@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < minRTO { - s.rto = minRTO + if s.rto < MinRTO { + s.rto = MinRTO } } @@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool { s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() - // Give up if we've waited more than a minute since the last resend. - if s.rto >= 60*time.Second { + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + s.ep.mu.RLock() + uto := s.ep.userTimeout + s.ep.mu.RUnlock() + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. + s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } + + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := MaxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + if remaining <= 0 || s.rto >= MaxRTO { return false } @@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool { // below. s.rto *= 2 + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. // // Retransmit timeouts: @@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) { // RFC 6298 Rule 5.3 if s.sndUna == s.sndNxt { s.outstanding = 0 + // Reset firstRetransmittedSegXmitTime to the zero value. + s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() } } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 12eff8afc..8b20c3455 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) } + +// saveFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { + return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} +} + +// loadFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { + s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index bc5cfcf0e..2a83f7bcc 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -323,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) } func TestTCPResetsReceivedIncrement(t *testing.T) { @@ -460,18 +460,17 @@ func TestConnectResetAfterClose(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) break } } -// TestClosingWithEnqueuedSegments tests handling of -// still enqueued segments when the endpoint transitions -// to StateClose. The in-flight segments would be re-enqueued -// to a any listening endpoint. +// TestClosingWithEnqueuedSegments tests handling of still enqueued segments +// when the endpoint transitions to StateClose. The in-flight segments would be +// re-enqueued to a any listening endpoint. func TestClosingWithEnqueuedSegments(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -576,8 +575,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) } @@ -914,7 +913,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -942,7 +941,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -4291,8 +4290,9 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) @@ -4382,13 +4382,29 @@ func TestKeepalive(t *testing.T) { ) } + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + // The connection should be terminated after 5 unacked keepalives. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), ) @@ -6157,8 +6173,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) @@ -6336,7 +6352,147 @@ func TestTCPCloseWithData(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + userTimeout := 50 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for a little over the minimum retransmit timeout of 200ms for + // the retransmitTimer to fire and close the connection. + time.Sleep(tcp.MinRTO + 10*time.Millisecond) + + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") + + next += uint32(len(view)) + + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} + +func TestKeepaliveWithUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + const keepAliveInterval = 10 * time.Millisecond + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) + c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + + // Set userTimeout to be the duration for 3 keepalive probes. + userTimeout := 30 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Now receive 2 keepalives, but don't ACK them. The connection should + // be reset when the 3rd one should be sent due to userTimeout being + // 30ms and each keepalive probe should be sent 10ms apart as set above after + // the connection has been idle for 10ms. + for i := 0; i < 2; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + + // The connection should be terminated after 30ms. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index fa4358ae4..761c3a9fe 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -206,7 +206,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked // before function end. - // ds.reset() + // ds.reset(); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -603,6 +603,60 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { SyscallSucceeds()); } +TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the userTimeout on the listening socket. + constexpr int kUserTimeout = 10; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kUserTimeout, sizeof(kUserTimeout)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + // Verify that the accepted socket inherited the user timeout set on + // listening socket. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kUserTimeout); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index c74273436..57ce8e169 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -812,5 +812,68 @@ TEST_P(TCPSocketPairTest, TestTCPCloseWithData) { ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); } +TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kZero, sizeof(kZero)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kNeg = -10; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kNeg, sizeof(kNeg)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAbove = 10; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kAbove, sizeof(kAbove)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kAbove); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 99863b0ed..c503f3568 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1175,6 +1175,31 @@ TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) { } } +TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + { + constexpr int kTCPUserTimeout = -1; + EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallFailsWithErrno(EINVAL)); + } + + // kTCPUserTimeout is in milliseconds. + constexpr int kTCPUserTimeout = 100; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPUserTimeout); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 27500d529f7fb87eef8812278fd1bbca67bcba72 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 9 Jan 2020 22:00:42 -0800 Subject: New sync package. * Rename syncutil to sync. * Add aliases to sync types. * Replace existing usage of standard library sync package. This will make it easier to swap out synchronization primitives. For example, this will allow us to use primitives from github.com/sasha-s/go-deadlock to check for lock ordering violations. Updates #1472 PiperOrigin-RevId: 289033387 --- pkg/amutex/BUILD | 1 + pkg/amutex/amutex_test.go | 3 +- pkg/atomicbitops/BUILD | 1 + pkg/atomicbitops/atomic_bitops_test.go | 3 +- pkg/compressio/BUILD | 5 +- pkg/compressio/compressio.go | 2 +- pkg/control/server/BUILD | 1 + pkg/control/server/server.go | 2 +- pkg/eventchannel/BUILD | 2 + pkg/eventchannel/event.go | 2 +- pkg/eventchannel/event_test.go | 2 +- pkg/fdchannel/BUILD | 1 + pkg/fdchannel/fdchannel_test.go | 3 +- pkg/fdnotifier/BUILD | 1 + pkg/fdnotifier/fdnotifier.go | 2 +- pkg/flipcall/BUILD | 3 +- pkg/flipcall/flipcall_example_test.go | 3 +- pkg/flipcall/flipcall_test.go | 3 +- pkg/flipcall/flipcall_unsafe.go | 10 +- pkg/gate/BUILD | 1 + pkg/gate/gate_test.go | 2 +- pkg/linewriter/BUILD | 1 + pkg/linewriter/linewriter.go | 3 +- pkg/log/BUILD | 5 +- pkg/log/log.go | 2 +- pkg/metric/BUILD | 1 + pkg/metric/metric.go | 2 +- pkg/p9/BUILD | 1 + pkg/p9/client.go | 2 +- pkg/p9/p9test/BUILD | 2 + pkg/p9/p9test/client_test.go | 2 +- pkg/p9/p9test/p9test.go | 2 +- pkg/p9/path_tree.go | 3 +- pkg/p9/pool.go | 2 +- pkg/p9/server.go | 2 +- pkg/p9/transport.go | 2 +- pkg/procid/BUILD | 2 + pkg/procid/procid_test.go | 3 +- pkg/rand/BUILD | 5 +- pkg/rand/rand_linux.go | 2 +- pkg/refs/BUILD | 2 + pkg/refs/refcounter.go | 2 +- pkg/refs/refcounter_test.go | 3 +- pkg/sentry/arch/BUILD | 1 + pkg/sentry/arch/arch_x86.go | 2 +- pkg/sentry/control/BUILD | 1 + pkg/sentry/control/pprof.go | 2 +- pkg/sentry/device/BUILD | 5 +- pkg/sentry/device/device.go | 2 +- pkg/sentry/fs/BUILD | 3 +- pkg/sentry/fs/copy_up.go | 2 +- pkg/sentry/fs/copy_up_test.go | 2 +- pkg/sentry/fs/dirent.go | 2 +- pkg/sentry/fs/dirent_cache.go | 3 +- pkg/sentry/fs/dirent_cache_limiter.go | 3 +- pkg/sentry/fs/fdpipe/BUILD | 1 + pkg/sentry/fs/fdpipe/pipe.go | 2 +- pkg/sentry/fs/fdpipe/pipe_state.go | 2 +- pkg/sentry/fs/file.go | 2 +- pkg/sentry/fs/file_overlay.go | 2 +- pkg/sentry/fs/filesystems.go | 2 +- pkg/sentry/fs/fs.go | 3 +- pkg/sentry/fs/fsutil/BUILD | 1 + pkg/sentry/fs/fsutil/host_file_mapper.go | 2 +- pkg/sentry/fs/fsutil/host_mappable.go | 2 +- pkg/sentry/fs/fsutil/inode.go | 3 +- pkg/sentry/fs/fsutil/inode_cached.go | 2 +- pkg/sentry/fs/gofer/BUILD | 1 + pkg/sentry/fs/gofer/inode.go | 2 +- pkg/sentry/fs/gofer/session.go | 2 +- pkg/sentry/fs/host/BUILD | 1 + pkg/sentry/fs/host/inode.go | 2 +- pkg/sentry/fs/host/socket.go | 2 +- pkg/sentry/fs/host/tty.go | 3 +- pkg/sentry/fs/inode.go | 3 +- pkg/sentry/fs/inode_inotify.go | 3 +- pkg/sentry/fs/inotify.go | 2 +- pkg/sentry/fs/inotify_watch.go | 2 +- pkg/sentry/fs/lock/BUILD | 1 + pkg/sentry/fs/lock/lock.go | 2 +- pkg/sentry/fs/mounts.go | 2 +- pkg/sentry/fs/overlay.go | 5 +- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/seqfile/BUILD | 1 + pkg/sentry/fs/proc/seqfile/seqfile.go | 2 +- pkg/sentry/fs/proc/sys_net.go | 2 +- pkg/sentry/fs/ramfs/BUILD | 1 + pkg/sentry/fs/ramfs/dir.go | 2 +- pkg/sentry/fs/restore.go | 2 +- pkg/sentry/fs/tmpfs/BUILD | 1 + pkg/sentry/fs/tmpfs/inode_file.go | 2 +- pkg/sentry/fs/tty/BUILD | 1 + pkg/sentry/fs/tty/dir.go | 2 +- pkg/sentry/fs/tty/line_discipline.go | 2 +- pkg/sentry/fs/tty/queue.go | 3 +- pkg/sentry/fsimpl/ext/BUILD | 1 + pkg/sentry/fsimpl/ext/directory.go | 3 +- pkg/sentry/fsimpl/ext/filesystem.go | 2 +- pkg/sentry/fsimpl/ext/regular_file.go | 2 +- pkg/sentry/fsimpl/kernfs/BUILD | 2 + pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 2 +- pkg/sentry/fsimpl/tmpfs/BUILD | 1 + pkg/sentry/fsimpl/tmpfs/regular_file.go | 2 +- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 2 +- pkg/sentry/kernel/BUILD | 5 +- pkg/sentry/kernel/abstract_socket_namespace.go | 2 +- pkg/sentry/kernel/auth/BUILD | 3 +- pkg/sentry/kernel/auth/user_namespace.go | 2 +- pkg/sentry/kernel/epoll/BUILD | 1 + pkg/sentry/kernel/epoll/epoll.go | 2 +- pkg/sentry/kernel/eventfd/BUILD | 1 + pkg/sentry/kernel/eventfd/eventfd.go | 2 +- pkg/sentry/kernel/fasync/BUILD | 1 + pkg/sentry/kernel/fasync/fasync.go | 3 +- pkg/sentry/kernel/fd_table.go | 2 +- pkg/sentry/kernel/fd_table_test.go | 2 +- pkg/sentry/kernel/fs_context.go | 2 +- pkg/sentry/kernel/futex/BUILD | 8 +- pkg/sentry/kernel/futex/futex.go | 3 +- pkg/sentry/kernel/futex/futex_test.go | 2 +- pkg/sentry/kernel/kernel.go | 2 +- pkg/sentry/kernel/memevent/BUILD | 1 + pkg/sentry/kernel/memevent/memory_events.go | 2 +- pkg/sentry/kernel/pipe/BUILD | 1 + pkg/sentry/kernel/pipe/buffer.go | 2 +- pkg/sentry/kernel/pipe/node.go | 3 +- pkg/sentry/kernel/pipe/pipe.go | 2 +- pkg/sentry/kernel/pipe/pipe_util.go | 2 +- pkg/sentry/kernel/pipe/vfs.go | 3 +- pkg/sentry/kernel/semaphore/BUILD | 1 + pkg/sentry/kernel/semaphore/semaphore.go | 2 +- pkg/sentry/kernel/shm/BUILD | 1 + pkg/sentry/kernel/shm/shm.go | 2 +- pkg/sentry/kernel/signal_handlers.go | 3 +- pkg/sentry/kernel/signalfd/BUILD | 1 + pkg/sentry/kernel/signalfd/signalfd.go | 3 +- pkg/sentry/kernel/syscalls.go | 2 +- pkg/sentry/kernel/syslog.go | 3 +- pkg/sentry/kernel/task.go | 5 +- pkg/sentry/kernel/thread_group.go | 2 +- pkg/sentry/kernel/threads.go | 2 +- pkg/sentry/kernel/time/BUILD | 1 + pkg/sentry/kernel/time/time.go | 2 +- pkg/sentry/kernel/timekeeper.go | 2 +- pkg/sentry/kernel/tty.go | 2 +- pkg/sentry/kernel/uts_namespace.go | 3 +- pkg/sentry/limits/BUILD | 1 + pkg/sentry/limits/limits.go | 3 +- pkg/sentry/mm/BUILD | 2 +- pkg/sentry/mm/aio_context.go | 3 +- pkg/sentry/mm/mm.go | 8 +- pkg/sentry/pgalloc/BUILD | 1 + pkg/sentry/pgalloc/pgalloc.go | 2 +- pkg/sentry/platform/interrupt/BUILD | 1 + pkg/sentry/platform/interrupt/interrupt.go | 3 +- pkg/sentry/platform/kvm/BUILD | 1 + pkg/sentry/platform/kvm/address_space.go | 2 +- pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go | 2 - pkg/sentry/platform/kvm/kvm.go | 2 +- pkg/sentry/platform/kvm/machine.go | 2 +- pkg/sentry/platform/ptrace/BUILD | 1 + pkg/sentry/platform/ptrace/ptrace.go | 2 +- pkg/sentry/platform/ptrace/subprocess.go | 2 +- .../platform/ptrace/subprocess_linux_unsafe.go | 2 +- pkg/sentry/platform/ring0/defs.go | 2 +- pkg/sentry/platform/ring0/defs_amd64.go | 1 + pkg/sentry/platform/ring0/defs_arm64.go | 1 + pkg/sentry/platform/ring0/pagetables/BUILD | 5 +- pkg/sentry/platform/ring0/pagetables/pcids_x86.go | 2 +- pkg/sentry/socket/netlink/BUILD | 1 + pkg/sentry/socket/netlink/port/BUILD | 1 + pkg/sentry/socket/netlink/port/port.go | 3 +- pkg/sentry/socket/netlink/socket.go | 2 +- pkg/sentry/socket/netstack/BUILD | 1 + pkg/sentry/socket/netstack/netstack.go | 2 +- pkg/sentry/socket/rpcinet/conn/BUILD | 1 + pkg/sentry/socket/rpcinet/conn/conn.go | 2 +- pkg/sentry/socket/rpcinet/notifier/BUILD | 1 + pkg/sentry/socket/rpcinet/notifier/notifier.go | 2 +- pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 3 +- pkg/sentry/socket/unix/transport/queue.go | 3 +- pkg/sentry/socket/unix/transport/unix.go | 2 +- pkg/sentry/syscalls/linux/BUILD | 1 + pkg/sentry/syscalls/linux/error.go | 2 +- pkg/sentry/time/BUILD | 4 +- pkg/sentry/time/calibrated_clock.go | 2 +- pkg/sentry/usage/BUILD | 1 + pkg/sentry/usage/memory.go | 2 +- pkg/sentry/vfs/BUILD | 3 +- pkg/sentry/vfs/dentry.go | 2 +- pkg/sentry/vfs/file_description_impl_util.go | 2 +- pkg/sentry/vfs/mount_test.go | 3 +- pkg/sentry/vfs/mount_unsafe.go | 4 +- pkg/sentry/vfs/pathname.go | 3 +- pkg/sentry/vfs/resolving_path.go | 2 +- pkg/sentry/vfs/vfs.go | 2 +- pkg/sentry/watchdog/BUILD | 1 + pkg/sentry/watchdog/watchdog.go | 2 +- pkg/sync/BUILD | 53 +++++++ pkg/sync/LICENSE | 27 ++++ pkg/sync/README.md | 5 + pkg/sync/aliases.go | 37 +++++ pkg/sync/atomicptr_unsafe.go | 47 +++++++ pkg/sync/atomicptrtest/BUILD | 29 ++++ pkg/sync/atomicptrtest/atomicptr_test.go | 31 +++++ pkg/sync/downgradable_rwmutex_test.go | 150 ++++++++++++++++++++ pkg/sync/downgradable_rwmutex_unsafe.go | 146 ++++++++++++++++++++ pkg/sync/memmove_unsafe.go | 28 ++++ pkg/sync/norace_unsafe.go | 35 +++++ pkg/sync/race_unsafe.go | 41 ++++++ pkg/sync/seqatomic_unsafe.go | 72 ++++++++++ pkg/sync/seqatomictest/BUILD | 33 +++++ pkg/sync/seqatomictest/seqatomic_test.go | 132 ++++++++++++++++++ pkg/sync/seqcount.go | 149 ++++++++++++++++++++ pkg/sync/seqcount_test.go | 153 +++++++++++++++++++++ pkg/sync/syncutil.go | 7 + pkg/syncutil/BUILD | 52 ------- pkg/syncutil/LICENSE | 27 ---- pkg/syncutil/README.md | 5 - pkg/syncutil/atomicptr_unsafe.go | 47 ------- pkg/syncutil/atomicptrtest/BUILD | 29 ---- pkg/syncutil/atomicptrtest/atomicptr_test.go | 31 ----- pkg/syncutil/downgradable_rwmutex_test.go | 150 -------------------- pkg/syncutil/downgradable_rwmutex_unsafe.go | 146 -------------------- pkg/syncutil/memmove_unsafe.go | 28 ---- pkg/syncutil/norace_unsafe.go | 35 ----- pkg/syncutil/race_unsafe.go | 41 ------ pkg/syncutil/seqatomic_unsafe.go | 72 ---------- pkg/syncutil/seqatomictest/BUILD | 35 ----- pkg/syncutil/seqatomictest/seqatomic_test.go | 132 ------------------ pkg/syncutil/seqcount.go | 149 -------------------- pkg/syncutil/seqcount_test.go | 153 --------------------- pkg/syncutil/syncutil.go | 7 - pkg/tcpip/BUILD | 1 + pkg/tcpip/adapters/gonet/BUILD | 1 + pkg/tcpip/adapters/gonet/gonet.go | 2 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 2 +- pkg/tcpip/link/sharedmem/BUILD | 2 + pkg/tcpip/link/sharedmem/pipe/BUILD | 1 + pkg/tcpip/link/sharedmem/pipe/pipe_test.go | 3 +- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/network/fragmentation/BUILD | 1 + pkg/tcpip/network/fragmentation/fragmentation.go | 2 +- pkg/tcpip/network/fragmentation/reassembler.go | 2 +- pkg/tcpip/ports/BUILD | 1 + pkg/tcpip/ports/ports.go | 2 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/linkaddrcache.go | 2 +- pkg/tcpip/stack/linkaddrcache_test.go | 2 +- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/stack/transport_demuxer.go | 2 +- pkg/tcpip/tcpip.go | 2 +- pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 3 +- pkg/tcpip/transport/packet/BUILD | 1 + pkg/tcpip/transport/packet/endpoint.go | 3 +- pkg/tcpip/transport/raw/BUILD | 1 + pkg/tcpip/transport/raw/endpoint.go | 3 +- pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 2 +- pkg/tcpip/transport/tcp/endpoint_state.go | 2 +- pkg/tcpip/transport/tcp/forwarder.go | 3 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/tcp/segment_queue.go | 2 +- pkg/tcpip/transport/tcp/snd.go | 2 +- pkg/tcpip/transport/udp/BUILD | 1 + pkg/tcpip/transport/udp/endpoint.go | 3 +- pkg/tmutex/BUILD | 1 + pkg/tmutex/tmutex_test.go | 3 +- pkg/unet/BUILD | 1 + pkg/unet/unet_test.go | 3 +- pkg/urpc/BUILD | 1 + pkg/urpc/urpc.go | 2 +- pkg/waiter/BUILD | 1 + pkg/waiter/waiter.go | 2 +- runsc/boot/BUILD | 2 + runsc/boot/compat.go | 2 +- runsc/boot/limits.go | 2 +- runsc/boot/loader.go | 2 +- runsc/boot/loader_test.go | 2 +- runsc/cmd/BUILD | 1 + runsc/cmd/create.go | 1 + runsc/cmd/gofer.go | 2 +- runsc/cmd/start.go | 1 + runsc/container/BUILD | 2 + runsc/container/console_test.go | 2 +- runsc/container/container_test.go | 2 +- runsc/container/multi_container_test.go | 2 +- runsc/container/state_file.go | 2 +- runsc/fsgofer/BUILD | 1 + runsc/fsgofer/fsgofer.go | 2 +- runsc/sandbox/BUILD | 1 + runsc/sandbox/sandbox.go | 2 +- runsc/testutil/BUILD | 1 + runsc/testutil/testutil.go | 2 +- 303 files changed, 1507 insertions(+), 1368 deletions(-) create mode 100644 pkg/sync/BUILD create mode 100644 pkg/sync/LICENSE create mode 100644 pkg/sync/README.md create mode 100644 pkg/sync/aliases.go create mode 100644 pkg/sync/atomicptr_unsafe.go create mode 100644 pkg/sync/atomicptrtest/BUILD create mode 100644 pkg/sync/atomicptrtest/atomicptr_test.go create mode 100644 pkg/sync/downgradable_rwmutex_test.go create mode 100644 pkg/sync/downgradable_rwmutex_unsafe.go create mode 100644 pkg/sync/memmove_unsafe.go create mode 100644 pkg/sync/norace_unsafe.go create mode 100644 pkg/sync/race_unsafe.go create mode 100644 pkg/sync/seqatomic_unsafe.go create mode 100644 pkg/sync/seqatomictest/BUILD create mode 100644 pkg/sync/seqatomictest/seqatomic_test.go create mode 100644 pkg/sync/seqcount.go create mode 100644 pkg/sync/seqcount_test.go create mode 100644 pkg/sync/syncutil.go delete mode 100644 pkg/syncutil/BUILD delete mode 100644 pkg/syncutil/LICENSE delete mode 100644 pkg/syncutil/README.md delete mode 100644 pkg/syncutil/atomicptr_unsafe.go delete mode 100644 pkg/syncutil/atomicptrtest/BUILD delete mode 100644 pkg/syncutil/atomicptrtest/atomicptr_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_unsafe.go delete mode 100644 pkg/syncutil/memmove_unsafe.go delete mode 100644 pkg/syncutil/norace_unsafe.go delete mode 100644 pkg/syncutil/race_unsafe.go delete mode 100644 pkg/syncutil/seqatomic_unsafe.go delete mode 100644 pkg/syncutil/seqatomictest/BUILD delete mode 100644 pkg/syncutil/seqatomictest/seqatomic_test.go delete mode 100644 pkg/syncutil/seqcount.go delete mode 100644 pkg/syncutil/seqcount_test.go delete mode 100644 pkg/syncutil/syncutil.go (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 6bc486b62..d99e37b40 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["amutex_test.go"], embed = [":amutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go index 1d7f45641..8a3952f2a 100644 --- a/pkg/amutex/amutex_test.go +++ b/pkg/amutex/amutex_test.go @@ -15,9 +15,10 @@ package amutex import ( - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) type sleeper struct { diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 36beaade9..6403c60c2 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -20,4 +20,5 @@ go_test( size = "small", srcs = ["atomic_bitops_test.go"], embed = [":atomicbitops"], + deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomic_bitops_test.go index 965e9be79..9466d3e23 100644 --- a/pkg/atomicbitops/atomic_bitops_test.go +++ b/pkg/atomicbitops/atomic_bitops_test.go @@ -16,8 +16,9 @@ package atomicbitops import ( "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) const iterations = 100 diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index a0b21d4bd..2bb581b18 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["compressio.go"], importpath = "gvisor.dev/gvisor/pkg/compressio", visibility = ["//:sandbox"], - deps = ["//pkg/binary"], + deps = [ + "//pkg/binary", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 3b0bb086e..5f52cbe74 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -52,9 +52,9 @@ import ( "hash" "io" "runtime" - "sync" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" ) var bufPool = sync.Pool{ diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD index 21adf3adf..adbd1e3f8 100644 --- a/pkg/control/server/BUILD +++ b/pkg/control/server/BUILD @@ -9,6 +9,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", ], diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go index a56152d10..41abe1f2d 100644 --- a/pkg/control/server/server.go +++ b/pkg/control/server/server.go @@ -22,9 +22,9 @@ package server import ( "os" - "sync" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 0b4b7cc44..9d68682c7 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ ":eventchannel_go_proto", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_golang_protobuf//ptypes:go_default_library_gen", @@ -40,6 +41,7 @@ go_test( srcs = ["event_test.go"], embed = [":eventchannel"], deps = [ + "//pkg/sync", "@com_github_golang_protobuf//proto:go_default_library", ], ) diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go index d37ad0428..9a29c58bd 100644 --- a/pkg/eventchannel/event.go +++ b/pkg/eventchannel/event.go @@ -22,13 +22,13 @@ package eventchannel import ( "encoding/binary" "fmt" - "sync" "syscall" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 3649097d6..7f41b4a27 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -16,11 +16,11 @@ package eventchannel import ( "fmt" - "sync" "testing" "time" "github.com/golang/protobuf/proto" + "gvisor.dev/gvisor/pkg/sync" ) // testEmitter is an emitter that can be used in tests. It records all events diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index 56495cbd9..b0478c672 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["fdchannel_test.go"], embed = [":fdchannel"], + deps = ["//pkg/sync"], ) diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go index 5d01dc636..7a8a63a59 100644 --- a/pkg/fdchannel/fdchannel_test.go +++ b/pkg/fdchannel/fdchannel_test.go @@ -17,10 +17,11 @@ package fdchannel import ( "io/ioutil" "os" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSendRecvFD(t *testing.T) { diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD index aca2d8a82..91a202a30 100644 --- a/pkg/fdnotifier/BUILD +++ b/pkg/fdnotifier/BUILD @@ -11,6 +11,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/fdnotifier", visibility = ["//:sandbox"], deps = [ + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go index f4aae1953..a6b63c982 100644 --- a/pkg/fdnotifier/fdnotifier.go +++ b/pkg/fdnotifier/fdnotifier.go @@ -22,10 +22,10 @@ package fdnotifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index e590a71ba..85bd83af1 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -19,7 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", - "//pkg/syncutil", + "//pkg/sync", ], ) @@ -31,4 +31,5 @@ go_test( "flipcall_test.go", ], embed = [":flipcall"], + deps = ["//pkg/sync"], ) diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go index 8d88b845d..2e28a149a 100644 --- a/pkg/flipcall/flipcall_example_test.go +++ b/pkg/flipcall/flipcall_example_test.go @@ -17,7 +17,8 @@ package flipcall import ( "bytes" "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) func Example() { diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index 168a487ec..33fd55a44 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -16,9 +16,10 @@ package flipcall import ( "runtime" - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) var testPacketWindowSize = pageSize diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 27b8939fc..ac974b232 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -18,7 +18,7 @@ import ( "reflect" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte { var ioSync int64 func raceBecomeActive() { - if syncutil.RaceEnabled { - syncutil.RaceAcquire((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceAcquire((unsafe.Pointer)(&ioSync)) } } func raceBecomeInactive() { - if syncutil.RaceEnabled { - syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) } } diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index 4b9321711..f22bd070d 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -19,5 +19,6 @@ go_test( ], deps = [ ":gate", + "//pkg/sync", ], ) diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go index 5dbd8d712..850693df8 100644 --- a/pkg/gate/gate_test.go +++ b/pkg/gate/gate_test.go @@ -15,11 +15,11 @@ package gate_test import ( - "sync" "testing" "time" "gvisor.dev/gvisor/pkg/gate" + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicEnter(t *testing.T) { diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index a5d980d14..bcde6d308 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["linewriter.go"], importpath = "gvisor.dev/gvisor/pkg/linewriter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go index cd6e4e2ce..a1b1285d4 100644 --- a/pkg/linewriter/linewriter.go +++ b/pkg/linewriter/linewriter.go @@ -17,7 +17,8 @@ package linewriter import ( "bytes" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Writer is an io.Writer which buffers input, flushing diff --git a/pkg/log/BUILD b/pkg/log/BUILD index fc5f5779b..0df0f2849 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -16,7 +16,10 @@ go_library( visibility = [ "//visibility:public", ], - deps = ["//pkg/linewriter"], + deps = [ + "//pkg/linewriter", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/log/log.go b/pkg/log/log.go index 9387586e6..91a81b288 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -25,12 +25,12 @@ import ( stdlog "log" "os" "runtime" - "sync" "sync/atomic" "syscall" "time" "gvisor.dev/gvisor/pkg/linewriter" + "gvisor.dev/gvisor/pkg/sync" ) // Level is the log level. diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index dd6ca6d39..9145f3233 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -14,6 +14,7 @@ go_library( ":metric_go_proto", "//pkg/eventchannel", "//pkg/log", + "//pkg/sync", ], ) diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index eadde06e4..93d4f2b8c 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -18,12 +18,12 @@ package metric import ( "errors" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index f32244c69..a3e05c96d 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/fdchannel", "//pkg/flipcall", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 221516c6c..4045e41fa 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -17,12 +17,12 @@ package p9 import ( "errors" "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 28707c0ca..f4edd68b2 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -70,6 +70,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/unet", "@com_github_golang_mock//gomock:go_default_library", ], @@ -83,6 +84,7 @@ go_test( deps = [ "//pkg/fd", "//pkg/p9", + "//pkg/sync", "@com_github_golang_mock//gomock:go_default_library", ], ) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e758148d..6e7bb3db2 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -22,7 +22,6 @@ import ( "os" "reflect" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" ) func TestPanic(t *testing.T) { diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 4d3271b37..dd8b01b6d 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -17,13 +17,13 @@ package p9test import ( "fmt" - "sync" "sync/atomic" "syscall" "testing" "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go index 865459411..72ef53313 100644 --- a/pkg/p9/path_tree.go +++ b/pkg/p9/path_tree.go @@ -16,7 +16,8 @@ package p9 import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // pathNode is a single node in a path traversal. diff --git a/pkg/p9/pool.go b/pkg/p9/pool.go index 52de889e1..2b14a5ce3 100644 --- a/pkg/p9/pool.go +++ b/pkg/p9/pool.go @@ -15,7 +15,7 @@ package p9 import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // pool is a simple allocator. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 40b8fa023..fdfa83648 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -17,7 +17,6 @@ package p9 import ( "io" "runtime/debug" - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/fdchannel" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 6e8b4bbcd..9c11e28ce 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -19,11 +19,11 @@ import ( "fmt" "io" "io/ioutil" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 078f084b2..b506813f0 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -21,6 +21,7 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) go_test( @@ -31,4 +32,5 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go index 88dd0b3ae..9ec08c3d6 100644 --- a/pkg/procid/procid_test.go +++ b/pkg/procid/procid_test.go @@ -17,9 +17,10 @@ package procid import ( "os" "runtime" - "sync" "syscall" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) // runOnMain is used to send functions to run on the main (initial) thread. diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD index f4f2001f3..9d5b4859b 100644 --- a/pkg/rand/BUILD +++ b/pkg/rand/BUILD @@ -10,5 +10,8 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/rand", visibility = ["//:sandbox"], - deps = ["@org_golang_x_sys//unix:go_default_library"], + deps = [ + "//pkg/sync", + "@org_golang_x_sys//unix:go_default_library", + ], ) diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go index 2b92db3e6..0bdad5fad 100644 --- a/pkg/rand/rand_linux.go +++ b/pkg/rand/rand_linux.go @@ -19,9 +19,9 @@ package rand import ( "crypto/rand" "io" - "sync" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" ) // reader implements an io.Reader that returns pseudorandom bytes. diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 7ad59dfd7..974d9af9b 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -27,6 +27,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", ], ) @@ -35,4 +36,5 @@ go_test( size = "small", srcs = ["refcounter_test.go"], embed = [":refs"], + deps = ["//pkg/sync"], ) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index ad69e0757..c45ba8200 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -21,10 +21,10 @@ import ( "fmt" "reflect" "runtime" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) // RefCounter is the interface to be implemented by objects that are reference diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go index ffd3d3f07..1ab4a4440 100644 --- a/pkg/refs/refcounter_test.go +++ b/pkg/refs/refcounter_test.go @@ -16,8 +16,9 @@ package refs import ( "reflect" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) type testCounter struct { diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 18c73cc24..ae3e364cd 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/limits", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 9294ac773..9f41e566f 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -19,7 +19,6 @@ package arch import ( "fmt" "io" - "sync" "syscall" "gvisor.dev/gvisor/pkg/binary" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 5522cecd0..2561a6109 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/strace", "//pkg/sentry/usage", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", ], diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index e1f2fea60..151808911 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -19,10 +19,10 @@ import ( "runtime" "runtime/pprof" "runtime/trace" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 1098ed777..97fa1512c 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["device.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/device", visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/abi/linux"], + deps = [ + "//pkg/abi/linux", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go index 47945d1a7..69e71e322 100644 --- a/pkg/sentry/device/device.go +++ b/pkg/sentry/device/device.go @@ -19,10 +19,10 @@ package device import ( "bytes" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Registry tracks all simple devices and related state on the system for diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index c035ffff7..7d5d72d5a 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -68,7 +68,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -115,6 +115,7 @@ go_test( "//pkg/sentry/fs/tmpfs", "//pkg/sentry/kernel/contexttest", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 9ac62c84d..734177e90 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -17,12 +17,12 @@ package fs import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index 1d80bf15a..738580c5f 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -19,13 +19,13 @@ import ( "crypto/rand" "fmt" "io" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/fs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 3cb73bd78..31fc4d87b 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -18,7 +18,6 @@ import ( "fmt" "path" "sort" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 60a15a275..25514ace4 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCache is an LRU cache of Dirents. The Dirent's refCount is diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go index ebb80bd50..525ee25f9 100644 --- a/pkg/sentry/fs/dirent_cache_limiter.go +++ b/pkg/sentry/fs/dirent_cache_limiter.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCacheLimiter acts as a global limit for all dirent caches in the diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 277ee4c31..cc43de69d 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 669ffcb75..5b6cfeb0a 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -17,7 +17,6 @@ package fdpipe import ( "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go index 29175fb3d..cee87f726 100644 --- a/pkg/sentry/fs/fdpipe/pipe_state.go +++ b/pkg/sentry/fs/fdpipe/pipe_state.go @@ -17,10 +17,10 @@ package fdpipe import ( "fmt" "io/ioutil" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // beforeSave is invoked by stateify. diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index a2f966cb6..7c4586296 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -16,7 +16,6 @@ package fs import ( "math" - "sync" "sync/atomic" "time" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 225e40186..8a633b1ba 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -16,13 +16,13 @@ package fs import ( "io" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go index b157fd228..c5b51620a 100644 --- a/pkg/sentry/fs/filesystems.go +++ b/pkg/sentry/fs/filesystems.go @@ -18,9 +18,9 @@ import ( "fmt" "sort" "strings" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags. diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index 8b2a5e6b2..26abf49e2 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -54,10 +54,9 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 9ca695a95..945b6270d 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -93,6 +93,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index b06a71cc2..837fc70b5 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -16,7 +16,6 @@ package fsutil import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostFileMapper caches mappings of an arbitrary host file descriptor. It is diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 30475f340..a625f0e26 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -16,7 +16,6 @@ package fsutil import ( "math" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostMappable implements memmap.Mappable and platform.File over a diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index 4e100a402..adf5ec69c 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -15,13 +15,12 @@ package fsutil import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 798920d18..20a014402 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -17,7 +17,6 @@ package fsutil import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Lock order (compare the lock order model in mm/mm.go): diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 4a005c605..fd870e8e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 91263ebdc..245fe2ef1 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -16,7 +16,6 @@ package gofer import ( "errors" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 4e358a46a..edc796ce0 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -16,7 +16,6 @@ package gofer import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refs" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 23daeb528..2b581aa69 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index a6e4a09e3..873a1c52d 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -15,7 +15,6 @@ package host import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 107336a3e..c076d5bdd 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -16,7 +16,6 @@ package host import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 90331e3b2..753ef8cd6 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -15,8 +15,6 @@ package host import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 91e2fde2f..468043df0 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -15,8 +15,6 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go index 0f2a66a79..efd3c962b 100644 --- a/pkg/sentry/fs/inode_inotify.go +++ b/pkg/sentry/fs/inode_inotify.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Watches is the collection of inotify watches on an inode. diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index ba3e0233d..cc7dd1c92 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -16,7 +16,6 @@ package fs import ( "io" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go index 0aa0a5e9b..900cba3ca 100644 --- a/pkg/sentry/fs/inotify_watch.go +++ b/pkg/sentry/fs/inotify_watch.go @@ -15,10 +15,10 @@ package fs import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Watch represent a particular inotify watch created by inotify_add_watch. diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 8d62642e7..2c332a82a 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -44,6 +44,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 636484424..41b040818 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -52,9 +52,9 @@ package lock import ( "fmt" "math" - "sync" "syscall" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index ac0398bd9..db3dfd096 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -19,7 +19,6 @@ import ( "math" "path" "strings" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 25573e986..4cad55327 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -17,13 +17,12 @@ package fs import ( "fmt" "strings" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -199,7 +198,7 @@ type overlayEntry struct { upper *Inode // dirCacheMu protects dirCache. - dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"` + dirCacheMu sync.DowngradableRWMutex `state:"nosave"` // dirCache is cache of DentAttrs from upper and lower Inodes. dirCache *SortedDentryMap diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 75cbb0622..94d46ab1b 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/waiter", diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index fe7067be1..38b246dff 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/proc/device", "//pkg/sentry/kernel/time", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go index 5fe823000..f9af191d5 100644 --- a/pkg/sentry/fs/proc/seqfile/seqfile.go +++ b/pkg/sentry/fs/proc/seqfile/seqfile.go @@ -17,7 +17,6 @@ package seqfile import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index bd93f83fa..a37e1fa06 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -17,7 +17,6 @@ package proc import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 012cb3e44..3fb7b0633 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index 78e082b8e..dcbb8eb2e 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -17,7 +17,6 @@ package ramfs import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go index f10168125..64c6a6ae9 100644 --- a/pkg/sentry/fs/restore.go +++ b/pkg/sentry/fs/restore.go @@ -15,7 +15,7 @@ package fs import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // RestoreEnvironment is the restore environment for file systems. It consists diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 59ce400c2..3400b940c 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index f86dfaa36..f1c87fe41 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "fmt" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 95ad98cb0..f6f60d0cf 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 2f639c823..88aa66b24 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -19,7 +19,6 @@ import ( "fmt" "math" "strconv" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 7cc0eb409..894964260 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -16,13 +16,13 @@ package tty import ( "bytes" - "sync" "unicode/utf8" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index 231e4e6eb..8b5d4699a 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -15,13 +15,12 @@ package tty import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index bc90330bc..903874141 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/syscalls/linux", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 91802dc1e..8944171c8 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -15,8 +15,6 @@ package ext import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 616fc002a..9afb1a84c 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -17,13 +17,13 @@ package ext import ( "errors" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index aec33e00a..d11153c90 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -16,7 +16,6 @@ package ext import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 39c03ee9d..809178250 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -39,6 +39,7 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) @@ -56,6 +57,7 @@ go_test( "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "@com_github_google_go-cmp//cmp:go_default_library", ], diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 752e0f659..1d469a0db 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -16,7 +16,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index d69b299ae..bb12f39a2 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -53,7 +53,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -61,6 +60,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemType implements vfs.FilesystemType. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 4b6b95f5f..5c9d580e1 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "runtime" - "sync" "testing" "github.com/google/go-cmp/cmp" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index a5b285987..82f5c2f41 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index f51e247a7..f200e767d 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 7be6faa5b..701826f90 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -26,7 +26,6 @@ package tmpfs import ( "fmt" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 2706927ff..ac85ba0c8 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -35,7 +35,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -209,7 +209,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/state", "//pkg/state/statefile", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", @@ -241,6 +241,7 @@ go_test( "//pkg/sentry/time", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 244655b5c..920fe4329 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -15,11 +15,11 @@ package kernel import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" ) // +stateify savable diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 04c244447..1aa72fa47 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "Credentials", }, @@ -64,6 +64,7 @@ go_library( "//pkg/bits", "//pkg/log", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go index af28ccc65..9dd52c860 100644 --- a/pkg/sentry/kernel/auth/user_namespace.go +++ b/pkg/sentry/kernel/auth/user_namespace.go @@ -16,8 +16,8 @@ package auth import ( "math" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index 3361e8b7d..c47f6b6fc 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 9c0a4e1b4..430311cc0 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -18,7 +18,6 @@ package epoll import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index e65b961e8..c831fbab2 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 12f0d429b..687690679 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -18,7 +18,6 @@ package eventfd import ( "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 49d81b712..6b36bc63e 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 6b0bb0324..d32c3e90a 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -16,12 +16,11 @@ package fasync import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 11f613a11..cd1501f85 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -18,7 +18,6 @@ import ( "bytes" "fmt" "math" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // FDFlags define flags for an individual descriptor. diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2bcb6216a..eccb7d1e7 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -16,7 +16,6 @@ package kernel import ( "runtime" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/filetest" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index ded27d668..2448c1d99 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -16,10 +16,10 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // FSContext contains filesystem context. diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 75ec31761..50db443ce 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "bucket", }, @@ -42,6 +42,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/memmap", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) @@ -51,5 +52,8 @@ go_test( size = "small", srcs = ["futex_test.go"], embed = [":futex"], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 278cc8143..d1931c8f4 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -18,11 +18,10 @@ package futex import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index 65e5d1428..c23126ca5 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -17,13 +17,13 @@ package futex import ( "math" "runtime" - "sync" "sync/atomic" "syscall" "testing" "unsafe" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // testData implements the Target interface, and allows us to diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8653d2f63..c85e97fef 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -36,7 +36,6 @@ import ( "fmt" "io" "path/filepath" - "sync" "sync/atomic" "time" @@ -67,6 +66,7 @@ import ( uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index d7a7d1169..7f36252a9 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/usage", + "//pkg/sync", ], ) diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go index b0d98e7f0..200565bb8 100644 --- a/pkg/sentry/kernel/memevent/memory_events.go +++ b/pkg/sentry/kernel/memevent/memory_events.go @@ -17,7 +17,6 @@ package memevent import ( - "sync" "time" "gvisor.dev/gvisor/pkg/eventchannel" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" ) var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.") diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 9d34f6d4d..5eeaeff66 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 95bee2d37..1c0f34269 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -16,9 +16,9 @@ package pipe import ( "io" - "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" ) // buffer encapsulates a queueable byte buffer. diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 4a19ab7ce..716f589af 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -15,12 +15,11 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 1a1b38f83..e4fd7d420 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -17,12 +17,12 @@ package pipe import ( "fmt" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index ef9641e6a..8394eb78b 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -17,7 +17,6 @@ package pipe import ( "io" "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 6416e0dd8..bf7461cbb 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -15,13 +15,12 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index f4c00cd86..13a961594 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index de9617e9d..18299814e 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -17,7 +17,6 @@ package semaphore import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index cd48945e6..7321b22ed 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -24,6 +24,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 19034a21e..8ddef7eb8 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -35,7 +35,6 @@ package shm import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go index a16f3d57f..768fda220 100644 --- a/pkg/sentry/kernel/signal_handlers.go +++ b/pkg/sentry/kernel/signal_handlers.go @@ -15,10 +15,9 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sync" ) // SignalHandlers holds information about signal actions. diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 9f7e19b4d..89e4d84b1 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 4b08d7d72..28be4a939 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -16,8 +16,6 @@ package signalfd import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 2fdee0282..d2d01add4 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -16,13 +16,13 @@ package kernel import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // maxSyscallNum is the highest supported syscall number. diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 8227ecf1d..4607cde2f 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -17,7 +17,8 @@ package kernel import ( "fmt" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // syslog represents a sentry-global kernel log. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index d25a7903b..978d66da8 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -17,7 +17,6 @@ package kernel import ( gocontext "context" "runtime/trace" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -37,7 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -85,7 +84,7 @@ type Task struct { // // gosched is protected by goschedSeq. gosched is owned by the task // goroutine. - goschedSeq syncutil.SeqCount `state:"nosave"` + goschedSeq sync.SeqCount `state:"nosave"` gosched TaskGoroutineSchedInfo // yieldCount is the number of times the task goroutine has called diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index c0197a563..768e958d2 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -15,7 +15,6 @@ package kernel import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 8267929a6..bf2dabb6e 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -16,9 +16,9 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 31847e1df..4e4de0512 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index 107394183..706de83ef 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -19,10 +19,10 @@ package time import ( "fmt" "math" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 76417342a..dc99301de 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -16,7 +16,6 @@ package kernel import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sync" ) // Timekeeper manages all of the kernel clocks. diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go index 048de26dc..464d2306a 100644 --- a/pkg/sentry/kernel/tty.go +++ b/pkg/sentry/kernel/tty.go @@ -14,7 +14,7 @@ package kernel -import "sync" +import "gvisor.dev/gvisor/pkg/sync" // TTY defines the relationship between a thread group and its controlling // terminal. diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go index 0a563e715..8ccf04bd1 100644 --- a/pkg/sentry/kernel/uts_namespace.go +++ b/pkg/sentry/kernel/uts_namespace.go @@ -15,9 +15,8 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" ) // UTSNamespace represents a UTS namespace, a holder of two system identifiers: diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 156e67bf8..9fa841e8b 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", ], ) diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go index b6c22656b..31b9e9ff6 100644 --- a/pkg/sentry/limits/limits.go +++ b/pkg/sentry/limits/limits.go @@ -16,8 +16,9 @@ package limits import ( - "sync" "syscall" + + "gvisor.dev/gvisor/pkg/sync" ) // LimitType defines a type of resource limit. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 839931f67..83e248431 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -118,7 +118,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usage", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/buffer", ], diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 1b746d030..4b48866ad 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -15,8 +15,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 58a5c186d..fa86ebced 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -35,8 +35,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -44,7 +42,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryManager implements a virtual address space. @@ -82,7 +80,7 @@ type MemoryManager struct { users int32 // mappingMu is analogous to Linux's struct mm_struct::mmap_sem. - mappingMu syncutil.DowngradableRWMutex `state:"nosave"` + mappingMu sync.DowngradableRWMutex `state:"nosave"` // vmas stores virtual memory areas. Since vmas are stored by value, // clients should usually use vmaIterator.ValuePtr() instead of @@ -125,7 +123,7 @@ type MemoryManager struct { // activeMu is loosely analogous to Linux's struct // mm_struct::page_table_lock. - activeMu syncutil.DowngradableRWMutex `state:"nosave"` + activeMu sync.DowngradableRWMutex `state:"nosave"` // pmas stores platform mapping areas used to implement vmas. Since pmas // are stored by value, clients should usually use pmaIterator.ValuePtr() diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index f404107af..a9a2642c5 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -73,6 +73,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index f7f7298c4..c99e023d9 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -25,7 +25,6 @@ import ( "fmt" "math" "os" - "sync" "sync/atomic" "syscall" "time" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index b6d008dbe..85e882df9 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -10,6 +10,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go index a4651f500..57be41647 100644 --- a/pkg/sentry/platform/interrupt/interrupt.go +++ b/pkg/sentry/platform/interrupt/interrupt.go @@ -17,7 +17,8 @@ package interrupt import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Receiver receives interrupt notifications from a Forwarder. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index f3afd98da..6a358d1d4 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -55,6 +55,7 @@ go_library( "//pkg/sentry/platform/safecopy", "//pkg/sentry/time", "//pkg/sentry/usermem", + "//pkg/sync", ], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index ea8b9632e..a25f3c449 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -15,13 +15,13 @@ package kvm import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // dirtySet tracks vCPUs for invalidation. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index e5fac0d6a..2f02c03cf 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -17,8 +17,6 @@ package kvm import ( - "unsafe" - "gvisor.dev/gvisor/pkg/sentry/arch" ) diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index f2c2c059e..a7850faed 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -18,13 +18,13 @@ package kvm import ( "fmt" "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // KVM represents a lightweight VM context. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 7d02ebf19..e6d912168 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,7 +17,6 @@ package kvm import ( "fmt" "runtime" - "sync" "sync/atomic" "syscall" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // machine contains state associated with the VM as a whole. diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 0df8cfa0f..cd13390c3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/safecopy", "//pkg/sentry/usermem", + "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 7b120a15d..bb0e03880 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -46,13 +46,13 @@ package ptrace import ( "os" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 20244fd95..15dc46a5b 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -18,7 +18,6 @@ import ( "fmt" "os" "runtime" - "sync" "syscall" "golang.org/x/sys/unix" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Linux kernel errnos which "should never be seen by user programs", but will diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index 2e6fbe488..245b20722 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -18,7 +18,6 @@ package ptrace import ( - "sync" "sync/atomic" "syscall" "unsafe" @@ -26,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/hostcpu" + "gvisor.dev/gvisor/pkg/sync" ) // maskPool contains reusable CPU masks for setting affinity. Unfortunately, diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 3f094c2a7..86fd5ed58 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -17,7 +17,7 @@ package ring0 import ( "syscall" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) // Kernel is a global kernel object. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 10dbd381f..9dae0dccb 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index dc0eeec01..a850ce6cf 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index e2e15ba5c..387a7f6c3 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -96,7 +96,10 @@ go_library( "//pkg/sentry/platform/kvm:__subpackages__", "//pkg/sentry/platform/ring0:__subpackages__", ], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go index 0f029f25d..e199bae18 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go @@ -17,7 +17,7 @@ package pagetables import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // limitPCID is the number of valid PCIDs. diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 136821963..103933144 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 463544c1a..2d9f4ba9b 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["port.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go index e9d3275b1..2cd3afc22 100644 --- a/pkg/sentry/socket/netlink/port/port.go +++ b/pkg/sentry/socket/netlink/port/port.go @@ -24,7 +24,8 @@ import ( "fmt" "math" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // maxPorts is a sanity limit on the maximum number of ports to allocate per diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d2e3644a6..cea56f4ed 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -17,7 +17,6 @@ package netlink import ( "math" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e414d8055..f78784569 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 764f11a6b..0affb8071 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,7 +29,6 @@ import ( "io" "math" "reflect" - "sync" "syscall" "time" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD index 23eadcb1b..b2677c659 100644 --- a/pkg/sentry/socket/rpcinet/conn/BUILD +++ b/pkg/sentry/socket/rpcinet/conn/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/binary", "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", + "//pkg/sync", "//pkg/syserr", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go index 356adad99..02f39c767 100644 --- a/pkg/sentry/socket/rpcinet/conn/conn.go +++ b/pkg/sentry/socket/rpcinet/conn/conn.go @@ -17,12 +17,12 @@ package conn import ( "fmt" - "sync" "sync/atomic" "syscall" "github.com/golang/protobuf/proto" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/unet" diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD index a3585e10d..a5954f22b 100644 --- a/pkg/sentry/socket/rpcinet/notifier/BUILD +++ b/pkg/sentry/socket/rpcinet/notifier/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", "//pkg/sentry/socket/rpcinet/conn", + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go index 7efe4301f..82b75d6dd 100644 --- a/pkg/sentry/socket/rpcinet/notifier/notifier.go +++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go @@ -17,12 +17,12 @@ package notifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 788ad70d2..d7ba95dff 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/ilist", "//pkg/refs", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index dea11e253..9e6fbc111 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -15,10 +15,9 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e27b1c714..5dcd3d95e 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,9 +15,8 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37c7ac3c1..fcc0da332 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,11 +16,11 @@ package transport import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index a76975cee..aa05e208a 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -91,6 +91,7 @@ go_library( "//pkg/sentry/syscalls", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 1d9018c96..60469549d 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -16,13 +16,13 @@ package linux import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 18e212dff..3cde3a0be 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "Parameters", }, @@ -36,7 +36,7 @@ go_library( deps = [ "//pkg/log", "//pkg/metric", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index 318503277..f9a93115d 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -17,11 +17,11 @@ package time import ( - "sync" "time" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index c32fe3241..5518ac3d0 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -18,5 +18,6 @@ go_library( deps = [ "//pkg/bits", "//pkg/memutil", + "//pkg/sync", ], ) diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index d6ef644d8..538c645eb 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -17,12 +17,12 @@ package usage import ( "fmt" "os" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryKind represents a type of memory used by the application. diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 4c6aa04a1..35c7be259 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -34,7 +34,7 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -54,6 +54,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 1bc9c4a38..486a76475 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -16,9 +16,9 @@ package vfs import ( "fmt" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 66eb57bc2..c00b3c84b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -17,13 +17,13 @@ package vfs import ( "bytes" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index adff0b94b..3b933468d 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -17,8 +17,9 @@ package vfs import ( "fmt" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestMountTableLookupEmpty(t *testing.T) { diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index ab13fa461..bd90d36c4 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -26,7 +26,7 @@ import ( "sync/atomic" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // mountKey represents the location at which a Mount is mounted. It is @@ -75,7 +75,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq syncutil.SeqCount + seq sync.SeqCount seed uint32 // for hashing keys // size holds both length (number of elements) and capacity (number of diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index 8e155654f..cf80df90e 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -15,10 +15,9 @@ package vfs import ( - "sync" - "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index f0641d314..8a0b382f6 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -16,11 +16,11 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index ea2db7031..1f21b0b31 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -29,12 +29,12 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD index 4d8435265..28f21f13d 100644 --- a/pkg/sentry/watchdog/BUILD +++ b/pkg/sentry/watchdog/BUILD @@ -13,5 +13,6 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", + "//pkg/sync", ], ) diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 5e4611333..bfb2fac26 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -32,7 +32,6 @@ package watchdog import ( "bytes" "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -40,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" ) // Opts configures the watchdog. diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD new file mode 100644 index 000000000..e8cd16b8f --- /dev/null +++ b/pkg/sync/BUILD @@ -0,0 +1,53 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +exports_files(["LICENSE"]) + +go_template( + name = "generic_atomicptr", + srcs = ["atomicptr_unsafe.go"], + types = [ + "Value", + ], +) + +go_template( + name = "generic_seqatomic", + srcs = ["seqatomic_unsafe.go"], + types = [ + "Value", + ], + deps = [ + ":sync", + ], +) + +go_library( + name = "sync", + srcs = [ + "aliases.go", + "downgradable_rwmutex_unsafe.go", + "memmove_unsafe.go", + "norace_unsafe.go", + "race_unsafe.go", + "seqcount.go", + "syncutil.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sync", +) + +go_test( + name = "sync_test", + size = "small", + srcs = [ + "downgradable_rwmutex_test.go", + "seqcount_test.go", + ], + embed = [":sync"], +) diff --git a/pkg/sync/LICENSE b/pkg/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/pkg/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sync/README.md b/pkg/sync/README.md new file mode 100644 index 000000000..2183c4e20 --- /dev/null +++ b/pkg/sync/README.md @@ -0,0 +1,5 @@ +# Syncutil + +This package provides additional synchronization primitives not provided by the +Go stdlib 'sync' package. It is partially derived from the upstream 'sync' +package from go1.10. diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go new file mode 100644 index 000000000..20c7ca041 --- /dev/null +++ b/pkg/sync/aliases.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "sync" +) + +// Aliases of standard library types. +type ( + // Mutex is an alias of sync.Mutex. + Mutex = sync.Mutex + + // RWMutex is an alias of sync.RWMutex. + RWMutex = sync.RWMutex + + // Cond is an alias of sync.Cond. + Cond = sync.Cond + + // Locker is an alias of sync.Locker. + Locker = sync.Locker + + // Once is an alias of sync.Once. + Once = sync.Once + + // Pool is an alias of sync.Pool. + Pool = sync.Pool + + // WaitGroup is an alias of sync.WaitGroup. + WaitGroup = sync.WaitGroup + + // Map is an alias of sync.Map. + Map = sync.Map +) diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go new file mode 100644 index 000000000..525c4beed --- /dev/null +++ b/pkg/sync/atomicptr_unsafe.go @@ -0,0 +1,47 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "sync/atomic" + "unsafe" +) + +// Value is a required type parameter. +type Value struct{} + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtr struct { + ptr unsafe.Pointer `state:".(*Value)"` +} + +func (p *AtomicPtr) savePtr() *Value { + return p.Load() +} + +func (p *AtomicPtr) loadPtr(v *Value) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtr) Load() *Value { + return (*Value)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtr) Store(x *Value) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD new file mode 100644 index 000000000..418eda29c --- /dev/null +++ b/pkg/sync/atomicptrtest/BUILD @@ -0,0 +1,29 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "atomicptr_int", + out = "atomicptr_int_unsafe.go", + package = "atomicptr", + suffix = "Int", + template = "//pkg/sync:generic_atomicptr", + types = { + "Value": "int", + }, +) + +go_library( + name = "atomicptr", + srcs = ["atomicptr_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr", +) + +go_test( + name = "atomicptr_test", + size = "small", + srcs = ["atomicptr_test.go"], + embed = [":atomicptr"], +) diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go new file mode 100644 index 000000000..8fdc5112e --- /dev/null +++ b/pkg/sync/atomicptrtest/atomicptr_test.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package atomicptr + +import ( + "testing" +) + +func newInt(val int) *int { + return &val +} + +func TestAtomicPtr(t *testing.T) { + var p AtomicPtrInt + if got := p.Load(); got != nil { + t.Errorf("initial value is %p (%v), wanted nil", got, got) + } + want := newInt(42) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } + want = newInt(100) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } +} diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/downgradable_rwmutex_test.go new file mode 100644 index 000000000..f04496bc5 --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_test.go @@ -0,0 +1,150 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// GOMAXPROCS=10 go test + +// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the +// addition of downgradingWriter and the renaming of num_iterations to +// numIterations to shut up Golint. + +package sync + +import ( + "fmt" + "runtime" + "sync/atomic" + "testing" +) + +func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { + m.RLock() + clocked <- true + <-cunlock + m.RUnlock() + cdone <- true +} + +func doTestParallelReaders(numReaders, gomaxprocs int) { + runtime.GOMAXPROCS(gomaxprocs) + var m DowngradableRWMutex + clocked := make(chan bool) + cunlock := make(chan bool) + cdone := make(chan bool) + for i := 0; i < numReaders; i++ { + go parallelReader(&m, clocked, cunlock, cdone) + } + // Wait for all parallel RLock()s to succeed. + for i := 0; i < numReaders; i++ { + <-clocked + } + for i := 0; i < numReaders; i++ { + cunlock <- true + } + // Wait for the goroutines to finish. + for i := 0; i < numReaders; i++ { + <-cdone + } +} + +func TestParallelReaders(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + doTestParallelReaders(1, 4) + doTestParallelReaders(3, 4) + doTestParallelReaders(4, 2) +} + +func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.RLock() + n := atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.Unlock() + } + cdone <- true +} + +func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.DowngradeLock() + n = atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + n = atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { + runtime.GOMAXPROCS(gomaxprocs) + // Number of active readers + 10000 * number of active writers. + var activity int32 + var rwm DowngradableRWMutex + cdone := make(chan bool) + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + var i int + for i = 0; i < numReaders/2; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + for ; i < numReaders; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + // Wait for the 4 writers and all readers to finish. + for i := 0; i < 4+numReaders; i++ { + <-cdone + } +} + +func TestDowngradableRWMutex(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + n := 1000 + if testing.Short() { + n = 5 + } + HammerDowngradableRWMutex(1, 1, n) + HammerDowngradableRWMutex(1, 3, n) + HammerDowngradableRWMutex(1, 10, n) + HammerDowngradableRWMutex(4, 1, n) + HammerDowngradableRWMutex(4, 3, n) + HammerDowngradableRWMutex(4, 10, n) + HammerDowngradableRWMutex(10, 1, n) + HammerDowngradableRWMutex(10, 3, n) + HammerDowngradableRWMutex(10, 10, n) + HammerDowngradableRWMutex(10, 5, n) +} diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/downgradable_rwmutex_unsafe.go new file mode 100644 index 000000000..9bb55cd3a --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_unsafe.go @@ -0,0 +1,146 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +// This is mostly copied from the standard library's sync/rwmutex.go. +// +// Happens-before relationships indicated to the race detector: +// - Unlock -> Lock (via writerSem) +// - Unlock -> RLock (via readerSem) +// - RUnlock -> Lock (via writerSem) +// - DowngradeLock -> RLock (via readerSem) + +package sync + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +//go:linkname runtimeSemacquire sync.runtime_Semacquire +func runtimeSemacquire(s *uint32) + +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + +// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock +// method. +type DowngradableRWMutex struct { + w sync.Mutex // held if there are pending writers + writerSem uint32 // semaphore for writers to wait for completing readers + readerSem uint32 // semaphore for readers to wait for completing writers + readerCount int32 // number of pending readers + readerWait int32 // number of departing readers +} + +const rwmutexMaxReaders = 1 << 30 + +// RLock locks rw for reading. +func (rw *DowngradableRWMutex) RLock() { + if RaceEnabled { + RaceDisable() + } + if atomic.AddInt32(&rw.readerCount, 1) < 0 { + // A writer is pending, wait for it. + runtimeSemacquire(&rw.readerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.readerSem)) + } +} + +// RUnlock undoes a single RLock call. +func (rw *DowngradableRWMutex) RUnlock() { + if RaceEnabled { + RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) + RaceDisable() + } + if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { + if r+1 == 0 || r+1 == -rwmutexMaxReaders { + panic("RUnlock of unlocked DowngradableRWMutex") + } + // A writer is pending. + if atomic.AddInt32(&rw.readerWait, -1) == 0 { + // The last reader unblocks the writer. + runtimeSemrelease(&rw.writerSem, false, 0) + } + } + if RaceEnabled { + RaceEnable() + } +} + +// Lock locks rw for writing. +func (rw *DowngradableRWMutex) Lock() { + if RaceEnabled { + RaceDisable() + } + // First, resolve competition with other writers. + rw.w.Lock() + // Announce to readers there is a pending writer. + r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders + // Wait for active readers. + if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { + runtimeSemacquire(&rw.writerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.writerSem)) + } +} + +// Unlock unlocks rw for writing. +func (rw *DowngradableRWMutex) Unlock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.writerSem)) + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) + if r >= rwmutexMaxReaders { + panic("Unlock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. + for i := 0; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} + +// DowngradeLock atomically unlocks rw for writing and locks it for reading. +func (rw *DowngradableRWMutex) DowngradeLock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer and one additional reader. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) + if r >= rwmutexMaxReaders+1 { + panic("DowngradeLock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. Note that this loop starts as 1 since r + // includes this goroutine. + for i := 1; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed to rw.w.Lock(). Note that they will still + // block on rw.writerSem since at least this reader exists, such that + // DowngradeLock() is atomic with the previous write lock. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go new file mode 100644 index 000000000..ad4a3a37e --- /dev/null +++ b/pkg/sync/memmove_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +package sync + +import ( + "unsafe" +) + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't +// define it because go_generics can't update the go:linkname annotation. +// Furthermore, go:linkname silently doesn't work if the local name is exported +// (this is of course undocumented), which is why this indirection is +// necessary. +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go new file mode 100644 index 000000000..006055dd6 --- /dev/null +++ b/pkg/sync/norace_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !race + +package sync + +import ( + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = false + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { +} diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go new file mode 100644 index 000000000..31d8fa9a6 --- /dev/null +++ b/pkg/sync/race_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package sync + +import ( + "runtime" + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = true + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { + runtime.RaceDisable() +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { + runtime.RaceEnable() +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { + runtime.RaceAcquire(addr) +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { + runtime.RaceRelease(addr) +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { + runtime.RaceReleaseMerge(addr) +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go new file mode 100644 index 000000000..eda6fb131 --- /dev/null +++ b/pkg/sync/seqatomic_unsafe.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/pkg/sync" +) + +// Value is a required type parameter. +// +// Value must not contain any pointers, including interface objects, function +// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs +// containing any of the above. An init() function will panic if this property +// does not hold. +type Value struct{} + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Value + for { + epoch := sc.BeginRead() + if sync.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, + // so it can't help us here. Instead, call runtime.memmove + // directly, which is not instrumented by the race detector. + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + // This is ~40% faster for short reads than going through memmove. + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { + var val Value + if sync.RaceEnabled { + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func init() { + var val Value + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD new file mode 100644 index 000000000..eba21518d --- /dev/null +++ b/pkg/sync/seqatomictest/BUILD @@ -0,0 +1,33 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "seqatomic_int", + out = "seqatomic_int_unsafe.go", + package = "seqatomic", + suffix = "Int", + template = "//pkg/sync:generic_seqatomic", + types = { + "Value": "int", + }, +) + +go_library( + name = "seqatomic", + srcs = ["seqatomic_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic", + deps = [ + "//pkg/sync", + ], +) + +go_test( + name = "seqatomic_test", + size = "small", + srcs = ["seqatomic_test.go"], + embed = [":seqatomic"], + deps = ["//pkg/sync"], +) diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go new file mode 100644 index 000000000..2c4568b07 --- /dev/null +++ b/pkg/sync/seqatomictest/seqatomic_test.go @@ -0,0 +1,132 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seqatomic + +import ( + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +func TestSeqAtomicLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicTryLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + epoch := seq.BeginRead() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } +} + +func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } + seq.EndWrite() +} + +func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } +} + +func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := SeqAtomicLoadInt(&seq, &data); got != want { + b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } + } + }) +} + +func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + epoch := seq.BeginRead() + for pb.Next() { + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } + } + }) +} + +// For comparison: +func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { + var a atomic.Value + const want = 42 + a.Store(int(want)) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := a.Load().(int); got != want { + b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) + } + } + }) +} diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go new file mode 100644 index 000000000..a1e895352 --- /dev/null +++ b/pkg/sync/seqcount.go @@ -0,0 +1,149 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "fmt" + "reflect" + "runtime" + "sync/atomic" +) + +// SeqCount is a synchronization primitive for optimistic reader/writer +// synchronization in cases where readers can work with stale data and +// therefore do not need to block writers. +// +// Compared to sync/atomic.Value: +// +// - Mutation of SeqCount-protected data does not require memory allocation, +// whereas atomic.Value generally does. This is a significant advantage when +// writes are common. +// +// - Atomic reads of SeqCount-protected data require copying. This is a +// disadvantage when atomic reads are common. +// +// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other +// operations to be made atomic with reads of SeqCount-protected data. +// +// - SeqCount may be less flexible: as of this writing, SeqCount-protected data +// cannot include pointers. +// +// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected +// data require instantiating function templates using go_generics (see +// seqatomic.go). +type SeqCount struct { + // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd + // if a writer critical section is active, and a read from data protected + // by this SeqCount is atomic iff epoch is the same even value before and + // after the read. + epoch uint32 +} + +// SeqCountEpoch tracks writer critical sections in a SeqCount. +type SeqCountEpoch struct { + val uint32 +} + +// We assume that: +// +// - All functions in sync/atomic that perform a memory read are at least a +// read fence: memory reads before calls to such functions cannot be reordered +// after the call, and memory reads after calls to such functions cannot be +// reordered before the call, even if those reads do not use sync/atomic. +// +// - All functions in sync/atomic that perform a memory write are at least a +// write fence: memory writes before calls to such functions cannot be +// reordered after the call, and memory writes after calls to such functions +// cannot be reordered before the call, even if those writes do not use +// sync/atomic. +// +// As of this writing, the Go memory model completely fails to describe +// sync/atomic, but these properties are implied by +// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. + +// BeginRead indicates the beginning of a reader critical section. Reader +// critical sections DO NOT BLOCK writer critical sections, so operations in a +// reader critical section MAY RACE with writer critical sections. Races are +// detected by ReadOk at the end of the reader critical section. Thus, the +// low-level structure of readers is generally: +// +// for { +// epoch := seq.BeginRead() +// // do something idempotent with seq-protected data +// if seq.ReadOk(epoch) { +// break +// } +// } +// +// However, since reader critical sections may race with writer critical +// sections, the Go race detector will (accurately) flag data races in readers +// using this pattern. Most users of SeqCount will need to use the +// SeqAtomicLoad function template in seqatomic.go. +func (s *SeqCount) BeginRead() SeqCountEpoch { + epoch := atomic.LoadUint32(&s.epoch) + for epoch&1 != 0 { + runtime.Gosched() + epoch = atomic.LoadUint32(&s.epoch) + } + return SeqCountEpoch{epoch} +} + +// ReadOk returns true if the reader critical section initiated by a previous +// call to BeginRead() that returned epoch did not race with any writer critical +// sections. +// +// ReadOk may be called any number of times during a reader critical section. +// Reader critical sections do not need to be explicitly terminated; the last +// call to ReadOk is implicitly the end of the reader critical section. +func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { + return atomic.LoadUint32(&s.epoch) == epoch.val +} + +// BeginWrite indicates the beginning of a writer critical section. +// +// SeqCount does not support concurrent writer critical sections; clients with +// concurrent writers must synchronize them using e.g. sync.Mutex. +func (s *SeqCount) BeginWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { + panic("SeqCount.BeginWrite during writer critical section") + } +} + +// EndWrite ends the effect of a preceding BeginWrite. +func (s *SeqCount) EndWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { + panic("SeqCount.EndWrite outside writer critical section") + } +} + +// PointersInType returns a list of pointers reachable from values named +// valName of the given type. +// +// PointersInType is not exhaustive, but it is guaranteed that if typ contains +// at least one pointer, then PointersInTypeOf returns a non-empty list. +func PointersInType(typ reflect.Type, valName string) []string { + switch kind := typ.Kind(); kind { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return nil + + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: + return []string{valName} + + case reflect.Array: + return PointersInType(typ.Elem(), valName+"[]") + + case reflect.Struct: + var ptrs []string + for i, n := 0, typ.NumField(); i < n; i++ { + field := typ.Field(i) + ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) + } + return ptrs + + default: + return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} + } +} diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go new file mode 100644 index 000000000..6eb7b4b59 --- /dev/null +++ b/pkg/sync/seqcount_test.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "reflect" + "testing" + "time" +) + +func TestSeqCountWriteUncontended(t *testing.T) { + var seq SeqCount + seq.BeginWrite() + seq.EndWrite() +} + +func TestSeqCountReadUncontended(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadAfterWrite(t *testing.T) { + var seq SeqCount + var data int32 + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadDuringWrite(t *testing.T) { + var seq SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountReadOkAfterWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } +} + +func TestSeqCountReadOkDuringWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } + seq.EndWrite() +} + +func BenchmarkSeqCountWriteUncontended(b *testing.B) { + var seq SeqCount + for i := 0; i < b.N; i++ { + seq.BeginWrite() + seq.EndWrite() + } +} + +func BenchmarkSeqCountReadUncontended(b *testing.B) { + var seq SeqCount + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + b.Fatalf("ReadOk: got false, wanted true") + } + } + }) +} + +func TestPointersInType(t *testing.T) { + for _, test := range []struct { + name string // used for both test and value name + val interface{} + ptrs []string + }{ + { + name: "EmptyStruct", + val: struct{}{}, + }, + { + name: "Int", + val: int(0), + }, + { + name: "MixedStruct", + val: struct { + b bool + I int + ExportedPtr *struct{} + unexportedPtr *struct{} + arr [2]int + ptrArr [2]*int + nestedStruct struct { + nestedNonptr int + nestedPtr *int + } + structArr [1]struct { + nonptr int + ptr *int + } + }{}, + ptrs: []string{ + "MixedStruct.ExportedPtr", + "MixedStruct.unexportedPtr", + "MixedStruct.ptrArr[]", + "MixedStruct.nestedStruct.nestedPtr", + "MixedStruct.structArr[].ptr", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + typ := reflect.TypeOf(test.val) + ptrs := PointersInType(typ, test.name) + t.Logf("Found pointers: %v", ptrs) + if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { + t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) + } + }) + } +} diff --git a/pkg/sync/syncutil.go b/pkg/sync/syncutil.go new file mode 100644 index 000000000..b16cf5333 --- /dev/null +++ b/pkg/sync/syncutil.go @@ -0,0 +1,7 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sync provides synchronization primitives. +package sync diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD deleted file mode 100644 index cb1f41628..000000000 --- a/pkg/syncutil/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -go_template( - name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - ], -) - -go_library( - name = "syncutil", - srcs = [ - "downgradable_rwmutex_unsafe.go", - "memmove_unsafe.go", - "norace_unsafe.go", - "race_unsafe.go", - "seqcount.go", - "syncutil.go", - ], - importpath = "gvisor.dev/gvisor/pkg/syncutil", -) - -go_test( - name = "syncutil_test", - size = "small", - srcs = [ - "downgradable_rwmutex_test.go", - "seqcount_test.go", - ], - embed = [":syncutil"], -) diff --git a/pkg/syncutil/LICENSE b/pkg/syncutil/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/syncutil/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/syncutil/README.md b/pkg/syncutil/README.md deleted file mode 100644 index 2183c4e20..000000000 --- a/pkg/syncutil/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Syncutil - -This package provides additional synchronization primitives not provided by the -Go stdlib 'sync' package. It is partially derived from the upstream 'sync' -package from go1.10. diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go deleted file mode 100644 index 525c4beed..000000000 --- a/pkg/syncutil/atomicptr_unsafe.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "sync/atomic" - "unsafe" -) - -// Value is a required type parameter. -type Value struct{} - -// An AtomicPtr is a pointer to a value of type Value that can be atomically -// loaded and stored. The zero value of an AtomicPtr represents nil. -// -// Note that copying AtomicPtr by value performs a non-atomic read of the -// stored pointer, which is unsafe if Store() can be called concurrently; in -// this case, do `dst.Store(src.Load())` instead. -// -// +stateify savable -type AtomicPtr struct { - ptr unsafe.Pointer `state:".(*Value)"` -} - -func (p *AtomicPtr) savePtr() *Value { - return p.Load() -} - -func (p *AtomicPtr) loadPtr(v *Value) { - p.Store(v) -} - -// Load returns the value set by the most recent Store. It returns nil if there -// has been no previous call to Store. -func (p *AtomicPtr) Load() *Value { - return (*Value)(atomic.LoadPointer(&p.ptr)) -} - -// Store sets the value returned by Load to x. -func (p *AtomicPtr) Store(x *Value) { - atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) -} diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD deleted file mode 100644 index 63f411a90..000000000 --- a/pkg/syncutil/atomicptrtest/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_int", - out = "atomicptr_int_unsafe.go", - package = "atomicptr", - suffix = "Int", - template = "//pkg/syncutil:generic_atomicptr", - types = { - "Value": "int", - }, -) - -go_library( - name = "atomicptr", - srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr", -) - -go_test( - name = "atomicptr_test", - size = "small", - srcs = ["atomicptr_test.go"], - embed = [":atomicptr"], -) diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go deleted file mode 100644 index 8fdc5112e..000000000 --- a/pkg/syncutil/atomicptrtest/atomicptr_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package atomicptr - -import ( - "testing" -) - -func newInt(val int) *int { - return &val -} - -func TestAtomicPtr(t *testing.T) { - var p AtomicPtrInt - if got := p.Load(); got != nil { - t.Errorf("initial value is %p (%v), wanted nil", got, got) - } - want := newInt(42) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } - want = newInt(100) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } -} diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go deleted file mode 100644 index ffaf7ecc7..000000000 --- a/pkg/syncutil/downgradable_rwmutex_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// GOMAXPROCS=10 go test - -// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the -// addition of downgradingWriter and the renaming of num_iterations to -// numIterations to shut up Golint. - -package syncutil - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" -) - -func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { - m.RLock() - clocked <- true - <-cunlock - m.RUnlock() - cdone <- true -} - -func doTestParallelReaders(numReaders, gomaxprocs int) { - runtime.GOMAXPROCS(gomaxprocs) - var m DowngradableRWMutex - clocked := make(chan bool) - cunlock := make(chan bool) - cdone := make(chan bool) - for i := 0; i < numReaders; i++ { - go parallelReader(&m, clocked, cunlock, cdone) - } - // Wait for all parallel RLock()s to succeed. - for i := 0; i < numReaders; i++ { - <-clocked - } - for i := 0; i < numReaders; i++ { - cunlock <- true - } - // Wait for the goroutines to finish. - for i := 0; i < numReaders; i++ { - <-cdone - } -} - -func TestParallelReaders(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - doTestParallelReaders(1, 4) - doTestParallelReaders(3, 4) - doTestParallelReaders(4, 2) -} - -func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.RLock() - n := atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.Unlock() - } - cdone <- true -} - -func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.DowngradeLock() - n = atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - n = atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { - runtime.GOMAXPROCS(gomaxprocs) - // Number of active readers + 10000 * number of active writers. - var activity int32 - var rwm DowngradableRWMutex - cdone := make(chan bool) - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - var i int - for i = 0; i < numReaders/2; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - for ; i < numReaders; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - // Wait for the 4 writers and all readers to finish. - for i := 0; i < 4+numReaders; i++ { - <-cdone - } -} - -func TestDowngradableRWMutex(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - n := 1000 - if testing.Short() { - n = 5 - } - HammerDowngradableRWMutex(1, 1, n) - HammerDowngradableRWMutex(1, 3, n) - HammerDowngradableRWMutex(1, 10, n) - HammerDowngradableRWMutex(4, 1, n) - HammerDowngradableRWMutex(4, 3, n) - HammerDowngradableRWMutex(4, 10, n) - HammerDowngradableRWMutex(10, 1, n) - HammerDowngradableRWMutex(10, 3, n) - HammerDowngradableRWMutex(10, 10, n) - HammerDowngradableRWMutex(10, 5, n) -} diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go deleted file mode 100644 index 51e11555d..000000000 --- a/pkg/syncutil/downgradable_rwmutex_unsafe.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.13 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -// This is mostly copied from the standard library's sync/rwmutex.go. -// -// Happens-before relationships indicated to the race detector: -// - Unlock -> Lock (via writerSem) -// - Unlock -> RLock (via readerSem) -// - RUnlock -> Lock (via writerSem) -// - DowngradeLock -> RLock (via readerSem) - -package syncutil - -import ( - "sync" - "sync/atomic" - "unsafe" -) - -//go:linkname runtimeSemacquire sync.runtime_Semacquire -func runtimeSemacquire(s *uint32) - -//go:linkname runtimeSemrelease sync.runtime_Semrelease -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) - -// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock -// method. -type DowngradableRWMutex struct { - w sync.Mutex // held if there are pending writers - writerSem uint32 // semaphore for writers to wait for completing readers - readerSem uint32 // semaphore for readers to wait for completing writers - readerCount int32 // number of pending readers - readerWait int32 // number of departing readers -} - -const rwmutexMaxReaders = 1 << 30 - -// RLock locks rw for reading. -func (rw *DowngradableRWMutex) RLock() { - if RaceEnabled { - RaceDisable() - } - if atomic.AddInt32(&rw.readerCount, 1) < 0 { - // A writer is pending, wait for it. - runtimeSemacquire(&rw.readerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.readerSem)) - } -} - -// RUnlock undoes a single RLock call. -func (rw *DowngradableRWMutex) RUnlock() { - if RaceEnabled { - RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) - RaceDisable() - } - if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { - if r+1 == 0 || r+1 == -rwmutexMaxReaders { - panic("RUnlock of unlocked DowngradableRWMutex") - } - // A writer is pending. - if atomic.AddInt32(&rw.readerWait, -1) == 0 { - // The last reader unblocks the writer. - runtimeSemrelease(&rw.writerSem, false, 0) - } - } - if RaceEnabled { - RaceEnable() - } -} - -// Lock locks rw for writing. -func (rw *DowngradableRWMutex) Lock() { - if RaceEnabled { - RaceDisable() - } - // First, resolve competition with other writers. - rw.w.Lock() - // Announce to readers there is a pending writer. - r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders - // Wait for active readers. - if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { - runtimeSemacquire(&rw.writerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.writerSem)) - } -} - -// Unlock unlocks rw for writing. -func (rw *DowngradableRWMutex) Unlock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.writerSem)) - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) - if r >= rwmutexMaxReaders { - panic("Unlock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. - for i := 0; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} - -// DowngradeLock atomically unlocks rw for writing and locks it for reading. -func (rw *DowngradableRWMutex) DowngradeLock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer and one additional reader. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) - if r >= rwmutexMaxReaders+1 { - panic("DowngradeLock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. Note that this loop starts as 1 since r - // includes this goroutine. - for i := 1; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed to rw.w.Lock(). Note that they will still - // block on rw.writerSem since at least this reader exists, such that - // DowngradeLock() is atomic with the previous write lock. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go deleted file mode 100644 index 348675baa..000000000 --- a/pkg/syncutil/memmove_unsafe.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.12 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -package syncutil - -import ( - "unsafe" -) - -//go:linkname memmove runtime.memmove -//go:noescape -func memmove(to, from unsafe.Pointer, n uintptr) - -// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't -// define it because go_generics can't update the go:linkname annotation. -// Furthermore, go:linkname silently doesn't work if the local name is exported -// (this is of course undocumented), which is why this indirection is -// necessary. -func Memmove(to, from unsafe.Pointer, n uintptr) { - memmove(to, from, n) -} diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go deleted file mode 100644 index 0a0a9deda..000000000 --- a/pkg/syncutil/norace_unsafe.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !race - -package syncutil - -import ( - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = false - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { -} diff --git a/pkg/syncutil/race_unsafe.go b/pkg/syncutil/race_unsafe.go deleted file mode 100644 index 206067ec1..000000000 --- a/pkg/syncutil/race_unsafe.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build race - -package syncutil - -import ( - "runtime" - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = true - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { - runtime.RaceDisable() -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { - runtime.RaceEnable() -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { - runtime.RaceAcquire(addr) -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { - runtime.RaceRelease(addr) -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { - runtime.RaceReleaseMerge(addr) -} diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go deleted file mode 100644 index cb6d2eb22..000000000 --- a/pkg/syncutil/seqatomic_unsafe.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "fmt" - "reflect" - "strings" - "unsafe" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -// Value is a required type parameter. -// -// Value must not contain any pointers, including interface objects, function -// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs -// containing any of the above. An init() function will panic if this property -// does not hold. -type Value struct{} - -// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in sc. -func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value { - // This function doesn't use SeqAtomicTryLoad because doing so is - // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value - for { - epoch := sc.BeginRead() - if syncutil.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - if sc.ReadOk(epoch) { - break - } - } - return val -} - -// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read -// would race with a writer critical section, SeqAtomicTryLoad returns -// (unspecified, false). -func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value - if syncutil.RaceEnabled { - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - val = *ptr - } - return val, sc.ReadOk(epoch) -} - -func init() { - var val Value - typ := reflect.TypeOf(val) - name := typ.Name() - if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 { - panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) - } -} diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD deleted file mode 100644 index ba18f3238..000000000 --- a/pkg/syncutil/seqatomictest/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "seqatomic_int", - out = "seqatomic_int_unsafe.go", - package = "seqatomic", - suffix = "Int", - template = "//pkg/syncutil:generic_seqatomic", - types = { - "Value": "int", - }, -) - -go_library( - name = "seqatomic", - srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic", - deps = [ - "//pkg/syncutil", - ], -) - -go_test( - name = "seqatomic_test", - size = "small", - srcs = ["seqatomic_test.go"], - embed = [":seqatomic"], - deps = [ - "//pkg/syncutil", - ], -) diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go deleted file mode 100644 index b0db44999..000000000 --- a/pkg/syncutil/seqatomictest/seqatomic_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqatomic - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -func TestSeqAtomicLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicTryLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - epoch := seq.BeginRead() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } -} - -func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } - seq.EndWrite() -} - -func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } -} - -func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := SeqAtomicLoadInt(&seq, &data); got != want { - b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } - } - }) -} - -func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - epoch := seq.BeginRead() - for pb.Next() { - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } - } - }) -} - -// For comparison: -func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { - var a atomic.Value - const want = 42 - a.Store(int(want)) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := a.Load().(int); got != want { - b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) - } - } - }) -} diff --git a/pkg/syncutil/seqcount.go b/pkg/syncutil/seqcount.go deleted file mode 100644 index 11d8dbfaa..000000000 --- a/pkg/syncutil/seqcount.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "fmt" - "reflect" - "runtime" - "sync/atomic" -) - -// SeqCount is a synchronization primitive for optimistic reader/writer -// synchronization in cases where readers can work with stale data and -// therefore do not need to block writers. -// -// Compared to sync/atomic.Value: -// -// - Mutation of SeqCount-protected data does not require memory allocation, -// whereas atomic.Value generally does. This is a significant advantage when -// writes are common. -// -// - Atomic reads of SeqCount-protected data require copying. This is a -// disadvantage when atomic reads are common. -// -// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other -// operations to be made atomic with reads of SeqCount-protected data. -// -// - SeqCount may be less flexible: as of this writing, SeqCount-protected data -// cannot include pointers. -// -// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected -// data require instantiating function templates using go_generics (see -// seqatomic.go). -type SeqCount struct { - // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd - // if a writer critical section is active, and a read from data protected - // by this SeqCount is atomic iff epoch is the same even value before and - // after the read. - epoch uint32 -} - -// SeqCountEpoch tracks writer critical sections in a SeqCount. -type SeqCountEpoch struct { - val uint32 -} - -// We assume that: -// -// - All functions in sync/atomic that perform a memory read are at least a -// read fence: memory reads before calls to such functions cannot be reordered -// after the call, and memory reads after calls to such functions cannot be -// reordered before the call, even if those reads do not use sync/atomic. -// -// - All functions in sync/atomic that perform a memory write are at least a -// write fence: memory writes before calls to such functions cannot be -// reordered after the call, and memory writes after calls to such functions -// cannot be reordered before the call, even if those writes do not use -// sync/atomic. -// -// As of this writing, the Go memory model completely fails to describe -// sync/atomic, but these properties are implied by -// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. - -// BeginRead indicates the beginning of a reader critical section. Reader -// critical sections DO NOT BLOCK writer critical sections, so operations in a -// reader critical section MAY RACE with writer critical sections. Races are -// detected by ReadOk at the end of the reader critical section. Thus, the -// low-level structure of readers is generally: -// -// for { -// epoch := seq.BeginRead() -// // do something idempotent with seq-protected data -// if seq.ReadOk(epoch) { -// break -// } -// } -// -// However, since reader critical sections may race with writer critical -// sections, the Go race detector will (accurately) flag data races in readers -// using this pattern. Most users of SeqCount will need to use the -// SeqAtomicLoad function template in seqatomic.go. -func (s *SeqCount) BeginRead() SeqCountEpoch { - epoch := atomic.LoadUint32(&s.epoch) - for epoch&1 != 0 { - runtime.Gosched() - epoch = atomic.LoadUint32(&s.epoch) - } - return SeqCountEpoch{epoch} -} - -// ReadOk returns true if the reader critical section initiated by a previous -// call to BeginRead() that returned epoch did not race with any writer critical -// sections. -// -// ReadOk may be called any number of times during a reader critical section. -// Reader critical sections do not need to be explicitly terminated; the last -// call to ReadOk is implicitly the end of the reader critical section. -func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { - return atomic.LoadUint32(&s.epoch) == epoch.val -} - -// BeginWrite indicates the beginning of a writer critical section. -// -// SeqCount does not support concurrent writer critical sections; clients with -// concurrent writers must synchronize them using e.g. sync.Mutex. -func (s *SeqCount) BeginWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { - panic("SeqCount.BeginWrite during writer critical section") - } -} - -// EndWrite ends the effect of a preceding BeginWrite. -func (s *SeqCount) EndWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { - panic("SeqCount.EndWrite outside writer critical section") - } -} - -// PointersInType returns a list of pointers reachable from values named -// valName of the given type. -// -// PointersInType is not exhaustive, but it is guaranteed that if typ contains -// at least one pointer, then PointersInTypeOf returns a non-empty list. -func PointersInType(typ reflect.Type, valName string) []string { - switch kind := typ.Kind(); kind { - case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return nil - - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: - return []string{valName} - - case reflect.Array: - return PointersInType(typ.Elem(), valName+"[]") - - case reflect.Struct: - var ptrs []string - for i, n := 0, typ.NumField(); i < n; i++ { - field := typ.Field(i) - ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) - } - return ptrs - - default: - return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} - } -} diff --git a/pkg/syncutil/seqcount_test.go b/pkg/syncutil/seqcount_test.go deleted file mode 100644 index 14d6aedea..000000000 --- a/pkg/syncutil/seqcount_test.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "reflect" - "testing" - "time" -) - -func TestSeqCountWriteUncontended(t *testing.T) { - var seq SeqCount - seq.BeginWrite() - seq.EndWrite() -} - -func TestSeqCountReadUncontended(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadAfterWrite(t *testing.T) { - var seq SeqCount - var data int32 - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadDuringWrite(t *testing.T) { - var seq SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountReadOkAfterWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } -} - -func TestSeqCountReadOkDuringWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } - seq.EndWrite() -} - -func BenchmarkSeqCountWriteUncontended(b *testing.B) { - var seq SeqCount - for i := 0; i < b.N; i++ { - seq.BeginWrite() - seq.EndWrite() - } -} - -func BenchmarkSeqCountReadUncontended(b *testing.B) { - var seq SeqCount - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - b.Fatalf("ReadOk: got false, wanted true") - } - } - }) -} - -func TestPointersInType(t *testing.T) { - for _, test := range []struct { - name string // used for both test and value name - val interface{} - ptrs []string - }{ - { - name: "EmptyStruct", - val: struct{}{}, - }, - { - name: "Int", - val: int(0), - }, - { - name: "MixedStruct", - val: struct { - b bool - I int - ExportedPtr *struct{} - unexportedPtr *struct{} - arr [2]int - ptrArr [2]*int - nestedStruct struct { - nestedNonptr int - nestedPtr *int - } - structArr [1]struct { - nonptr int - ptr *int - } - }{}, - ptrs: []string{ - "MixedStruct.ExportedPtr", - "MixedStruct.unexportedPtr", - "MixedStruct.ptrArr[]", - "MixedStruct.nestedStruct.nestedPtr", - "MixedStruct.structArr[].ptr", - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - typ := reflect.TypeOf(test.val) - ptrs := PointersInType(typ, test.name) - t.Logf("Found pointers: %v", ptrs) - if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { - t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) - } - }) - } -} diff --git a/pkg/syncutil/syncutil.go b/pkg/syncutil/syncutil.go deleted file mode 100644 index 66e750d06..000000000 --- a/pkg/syncutil/syncutil.go +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package syncutil provides synchronization primitives. -package syncutil diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e07ebd153..db06d02c6 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -15,6 +15,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip/buffer", "//pkg/tcpip/iptables", "//pkg/waiter", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 78df5a0b1..3df7d18d3 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index cd6ce930a..a2f44b496 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -20,9 +20,9 @@ import ( "errors" "io" "net" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 897c94821..66cc53ed4 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -16,6 +16,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index fa8a703d9..b7f60178e 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,10 +41,10 @@ package fdbased import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index a4f9cdd69..09165dd4c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -31,6 +32,7 @@ go_test( ], embed = [":sharedmem"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 6b5bc542c..a0d4ad0be 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -21,4 +21,5 @@ go_test( "pipe_test.go", ], embed = [":pipe"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index 59ef69a8b..dc239a0d0 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -18,8 +18,9 @@ import ( "math/rand" "reflect" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSimpleReadWrite(t *testing.T) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 080f9d667..655e537c4 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -23,11 +23,11 @@ package sharedmem import ( - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 89603c48f..5c729a439 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -22,11 +22,11 @@ import ( "math/rand" "os" "strings" - "sync" "syscall" "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index acf1e022c..ed16076fd 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", ], diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6da5238ec..92f2aa13a 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -19,9 +19,9 @@ package fragmentation import ( "fmt" "log" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9e002e396..0a83d81f2 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index e156b01f6..a6ef3bdcc 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 6c5e19e8f..b937cb84b 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -18,9 +18,9 @@ package ports import ( "math" "math/rand" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 826fca4de..6a8654105 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -36,6 +36,7 @@ go_library( "//pkg/ilist", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", @@ -80,6 +81,7 @@ go_test( embed = [":stack"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 267df60d1..403557fd7 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -16,10 +16,10 @@ package stack import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 9946b8fe8..1baa498d0 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,12 +16,12 @@ package stack import ( "fmt" - "sync" "sync/atomic" "testing" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3810c6602..fe557ccbd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,9 +16,9 @@ package stack import ( "strings" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 41bf9fd9b..a47ceba54 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,13 +21,13 @@ package stack import ( "encoding/binary" - "sync" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 67c21be42..f384a91de 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -18,8 +18,8 @@ import ( "fmt" "math/rand" "sort" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 72b5ce179..4a090ac86 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -35,10 +35,10 @@ import ( "reflect" "strconv" "strings" - "sync" "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d8c5b5058..3aa23d529 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index c7ce74cdd..330786f4c 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,8 +15,7 @@ package icmp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 44b58ff6b..4858d150c 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 07ffa8aba..fc5bc69fa 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,8 +25,7 @@ package packet import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 00991ac8e..2f2131ff7 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 85f7eb76b..ee9c4c58b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,8 +26,7 @@ package raw import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 3b353d56c..353bd06f4 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/log", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 5422ae80c..1ea996936 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -19,11 +19,11 @@ import ( "encoding/binary" "hash" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cdd69f360..613ec1775 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,11 +16,11 @@ package tcp import ( "encoding/binary" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 830bc1e3e..cca511fb9 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -19,12 +19,12 @@ import ( "fmt" "math" "strings" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 7aa4c3f0e..4b8d867bc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,9 +16,9 @@ package tcp import ( "fmt" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 4983bca81..7eb613be5 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -15,8 +15,7 @@ package tcp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index bc718064c..9a8f64aa6 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -22,9 +22,9 @@ package tcp import ( "strings" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index e0759225e..bd20a7ee9 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -15,7 +15,7 @@ package tcp import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // segmentQueue is a bounded, thread-safe queue of TCP segments. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8a947dc66..79f2d274b 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -16,11 +16,11 @@ package tcp import ( "math" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 97e4d5825..57ff123e3 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -30,6 +30,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 864dc8733..a4ff29a7d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,7 @@ package udp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 6afdb29b7..07778e4f7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "medium", srcs = ["tmutex_test.go"], embed = [":tmutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go index ce34c7962..05540696a 100644 --- a/pkg/tmutex/tmutex_test.go +++ b/pkg/tmutex/tmutex_test.go @@ -17,10 +17,11 @@ package tmutex import ( "fmt" "runtime" - "sync" "sync/atomic" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicLock(t *testing.T) { diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index 8f6f180e5..d1885ae66 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -24,4 +24,5 @@ go_test( "unet_test.go", ], embed = [":unet"], + deps = ["//pkg/sync"], ) diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index a3cc6f5d3..5c4b9e8e9 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -19,10 +19,11 @@ import ( "os" "path/filepath" "reflect" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func randomFilename() (string, error) { diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b6bbb0ea2..b8fdc3125 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/fd", "//pkg/log", + "//pkg/sync", "//pkg/unet", ], ) diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index df59ffab1..13b2ea314 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -27,10 +27,10 @@ import ( "os" "reflect" "runtime" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 0427bc41f..1c6890e52 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -24,6 +24,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/waiter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 8a65ed164..f708e95fa 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -58,7 +58,7 @@ package waiter import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // EventMask represents io events as used in the poll() syscall. diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 6226b63f8..3e20f8f2f 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -74,6 +74,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/link/fdbased", @@ -114,6 +115,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//runsc/fsgofer", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 352e710d2..9c23b9553 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -17,7 +17,6 @@ package boot import ( "fmt" "os" - "sync" "syscall" "github.com/golang/protobuf/proto" @@ -27,6 +26,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/strace" spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) func initCompatLogs(fd int) error { diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go index d1c0bb9b5..ce62236e5 100644 --- a/runsc/boot/limits.go +++ b/runsc/boot/limits.go @@ -16,12 +16,12 @@ package boot import ( "fmt" - "sync" "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // Mapping from linux resource names to limits.LimitType. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index bc1d0c1bb..fad72f4ab 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -20,7 +20,6 @@ import ( mrand "math/rand" "os" "runtime" - "sync" "sync/atomic" "syscall" gtime "time" @@ -46,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/watchdog" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 147ff7703..bec0dc292 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -19,7 +19,6 @@ import ( "math/rand" "os" "reflect" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/fsgofer" ) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 250845ad7..b94bc4fa0 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go index a4e3071b3..1815c93b9 100644 --- a/runsc/cmd/create.go +++ b/runsc/cmd/create.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 4831210c0..7df7995f0 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -21,7 +21,6 @@ import ( "os" "path/filepath" "strings" - "sync" "syscall" "flag" @@ -30,6 +29,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/fsgofer" diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index de2115dff..5e9bc53ab 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 2bd12120d..6dea179e4 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sentry/control", + "//pkg/sync", "//runsc/boot", "//runsc/cgroup", "//runsc/sandbox", @@ -53,6 +54,7 @@ go_test( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 5ed131a7f..060b63bf3 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -20,7 +20,6 @@ import ( "io" "os" "path/filepath" - "sync" "syscall" "testing" "time" @@ -29,6 +28,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index c10f85992..b54d8f712 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -26,7 +26,6 @@ import ( "reflect" "strconv" "strings" - "sync" "syscall" "testing" "time" @@ -39,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" "gvisor.dev/gvisor/runsc/specutils" diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 4ad09ceab..2da93ec5b 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -22,7 +22,6 @@ import ( "path" "path/filepath" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index d95151ea5..17a251530 100644 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go @@ -20,10 +20,10 @@ import ( "io/ioutil" "os" "path/filepath" - "sync" "github.com/gofrs/flock" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) const stateFileExtension = ".state" diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index afcb41801..a9582d92b 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/syserr", "//runsc/specutils", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index b59e1a70e..93606d051 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -29,7 +29,6 @@ import ( "path/filepath" "runtime" "strconv" - "sync" "syscall" "golang.org/x/sys/unix" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/specutils" ) diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 8001949d5..ddbc37456 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/platform", + "//pkg/sync", "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/urpc", diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index ce1452b87..ec72bdbfd 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -22,7 +22,6 @@ import ( "os" "os/exec" "strconv" - "sync" "syscall" "time" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index c96ca2eb6..3c3027cb5 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index 9632776d2..fb22eae39 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -34,7 +34,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "sync/atomic" "syscall" "time" @@ -42,6 +41,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" ) -- cgit v1.2.3 From dacd349d6fb4fc7453b1fbf694158fd25496ed42 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Fri, 10 Jan 2020 06:01:10 -0800 Subject: panic fix in retransmitTimerExpired. This is a band-aid fix for now to prevent panics. PiperOrigin-RevId: 289078453 --- pkg/tcpip/transport/tcp/snd.go | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 79f2d274b..fdff7ed81 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -442,6 +442,13 @@ func (s *sender) retransmitTimerExpired() bool { return true } + // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases + // when writeList is empty. Remove this once we have a proper fix for this + // issue. + if s.writeList.Front() == nil { + return true + } + s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() -- cgit v1.2.3 From a611fdaee3c14abe2222140ae0a8a742ebfd31ab Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 14 Jan 2020 14:14:17 -0800 Subject: Changes TCP packet dispatch to use a pool of goroutines. All inbound segments for connections in ESTABLISHED state are delivered to the endpoint's queue but for every segment delivered we also queue the endpoint for processing to a selected processor. This ensures that when there are a large number of connections in ESTABLISHED state the inbound packets are all handled by a small number of goroutines and significantly reduces the amount of work the goscheduler has to perform. We let connections in other states follow the current path where the endpoint's goroutine directly handles the segments. Updates #231 PiperOrigin-RevId: 289728325 --- benchmarks/tcp/tcp_proxy.go | 6 +- pkg/sleep/sleep_test.go | 31 +++ pkg/tcpip/stack/transport_demuxer.go | 54 ++++- pkg/tcpip/transport/tcp/BUILD | 15 +- pkg/tcpip/transport/tcp/accept.go | 9 +- pkg/tcpip/transport/tcp/connect.go | 310 ++++++++++++++++------------ pkg/tcpip/transport/tcp/dispatcher.go | 218 +++++++++++++++++++ pkg/tcpip/transport/tcp/endpoint.go | 303 ++++++++++++++++++--------- pkg/tcpip/transport/tcp/endpoint_state.go | 30 +-- pkg/tcpip/transport/tcp/protocol.go | 11 + pkg/tcpip/transport/tcp/rcv.go | 21 +- pkg/tcpip/transport/tcp/snd.go | 14 +- pkg/tcpip/transport/tcp/tcp_test.go | 11 +- test/syscalls/linux/socket_inet_loopback.cc | 2 +- test/syscalls/linux/tcp_socket.cc | 14 ++ 15 files changed, 769 insertions(+), 280 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/dispatcher.go (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index be0d7bdd6..dc96add66 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -85,7 +85,7 @@ func (netImpl) printStats() { const ( nicID = 1 // Fixed. - rcvBufSize = 1 << 20 // 1MB. + rcvBufSize = 4 << 20 // 1MB. ) type netstackImpl struct { @@ -130,6 +130,10 @@ func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) } + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, rcvBufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + } + if !*swgso && *gso != 0 { if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index 130806c86..af47e2ba1 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -376,6 +376,37 @@ func TestRace(t *testing.T) { } } +// TestRaceInOrder tests that multiple wakers can continuously send wake requests to +// the sleeper and that the wakers are retrieved in the order asserted. +func TestRaceInOrder(t *testing.T) { + const wakers = 100 + const wakeRequests = 10000 + + w := make([]Waker, wakers) + s := Sleeper{} + + // Associate each waker and start goroutines that will assert them. + for i := range w { + s.AddWaker(&w[i], i) + } + go func() { + n := 0 + for n < wakeRequests { + wk := w[n%len(w)] + wk.Assert() + n++ + } + }() + + // Wait for all wake up notifications from all wakers. + for i := 0; i < wakeRequests; i++ { + v, _ := s.Fetch(true) + if got, want := v, i%wakers; got != want { + t.Fatalf("got %d want %d", got, want) + } + } +} + // BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up // from 4 wakers when at least one is already asserted. func BenchmarkSleeperMultiSelect(b *testing.B) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f384a91de..d686e6eb8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + transEP := selectEndpoint(id, mpep, epsByNic.seed) + if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { + queuedProtocol.QueuePacket(r, transEP, id, pkt) + epsByNic.mu.RUnlock() + return + } + + transEP.HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } @@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { epsByNic.mu.Lock() defer epsByNic.mu.Unlock() @@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort } // This is a new binding. - multiPortEp := &multiPortEndpoint{} + multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} multiPortEp.endpointsMap = make(map[TransportEndpoint]int) multiPortEp.reuse = reusePort epsByNic.endpoints[bindToDevice] = multiPortEp @@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T // newTransportDemuxer. type transportDemuxer struct { // protocol is immutable. - protocol map[protocolIDs]*transportEndpoints + protocol map[protocolIDs]*transportEndpoints + queuedProtocols map[protocolIDs]queuedTransportProtocol +} + +// queuedTransportProtocol if supported by a protocol implementation will cause +// the dispatcher to delivery packets to the QueuePacket method instead of +// calling HandlePacket directly on the endpoint. +type queuedTransportProtocol interface { + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { - d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + d := &transportDemuxer{ + protocol: make(map[protocolIDs]*transportEndpoints), + queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), + } // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + protoIDs := protocolIDs{netProto, proto} + d.protocol[protoIDs] = &transportEndpoints{ endpoints: make(map[TransportEndpointID]*endpointsByNic), } + qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) + if isQueued { + d.queuedProtocols[protoIDs] = qTransProto + } } } @@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + demux *transportDemuxer + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. @@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() + queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] for i, endpoint := range ep.endpointsArr { // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + break + } endpoint.HandlePacket(r, id, pkt) break } + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + continue + } endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. @@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if epsByNic, ok := eps.endpoints[id]; ok { // There was already a binding. - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // This is a new binding. @@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol } eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 353bd06f4..0e3ab05ad 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -16,6 +16,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tcp_endpoint_list", + out = "tcp_endpoint_list.go", + package = "tcp", + prefix = "endpoint", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*endpoint", + "Linker": "*endpoint", + }, +) + go_library( name = "tcp", srcs = [ @@ -23,6 +35,7 @@ go_library( "connect.go", "cubic.go", "cubic_state.go", + "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", @@ -38,6 +51,7 @@ go_library( "segment_state.go", "snd.go", "snd_state.go", + "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", ], @@ -45,7 +59,6 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ - "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1ea996936..1a2e3efa9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. if l.listenEP != nil { l.listenEP.mu.Lock() - if l.listenEP.state != StateListen { + if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() return nil, tcpip.ErrConnectionAborted } @@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() { // instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.state + state := e.EndpointState() e.pendingAccepted.Add(1) defer e.pendingAccepted.Done() acceptedChan := e.acceptedChan e.mu.Unlock() + if state == StateListen { acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) @@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // We do not use transitionToStateEstablishedLocked here as there is // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() - n.state = StateEstablished n.isConnectNotified = true + n.setEndpointState(StateEstablished) // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = StateClose + e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 613ec1775..f3896715b 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.ep.mu.Lock() - h.ep.state = StateSynRecv + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() } @@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // SYN-RCVD state. h.state = handshakeSynRcvd h.ep.mu.Lock() - h.ep.state = StateSynRecv ttl := h.ep.ttl + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), // We only send SACKPermitted if the other side indicated it // permits SACK. This is not explicitly defined in the RFC but @@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { WS: h.rcvWndScale, TS: h.ep.sendTSOk, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } @@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error { WS: h.rcvWndScale, TS: true, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } @@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } if e.sackPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) @@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error { } func (e *endpoint) handleClose() *tcpip.Error { + if !e.EndpointState().connected() { + return nil + } // Drain the send queue. e.handleWrite() @@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } - e.state = StateError + e.setEndpointState(StateError) e.HardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the @@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } // completeWorkerLocked is called by the worker goroutine when it's about to -// exit. It marks the worker as completed and performs cleanup work if requested -// by Close(). +// exit. func (e *endpoint) completeWorkerLocked() { + // Worker is terminating(either due to moving to + // CLOSED or ERROR state, ensure we release all + // registrations port reservations even if the socket + // itself is not yet closed by the application. e.workerRunning = false if e.workerCleanup { e.cleanupLocked() @@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { e.rcvAutoParams.prevCopied = int(h.rcvWnd) e.rcvListMu.Unlock() } - h.ep.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished + e.setEndpointState(StateEstablished) } // transitionToStateCloseLocked ensures that the endpoint is @@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // delivered to this endpoint from the demuxer when the endpoint // is transitioned to StateClose. func (e *endpoint) transitionToStateCloseLocked() { - if e.state == StateClose { + if e.EndpointState() == StateClose { return } + // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() - e.state = StateClose + e.setEndpointState(StateClose) e.stack.Stats().TCP.EstablishedClosed.Increment() } @@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { s.decRef() return } - ep.(*endpoint).enqueueSegment(s) + if ep.(*endpoint).enqueueSegment(s) { + ep.(*endpoint).newSegmentWaker.Assert() + } } func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { @@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - s.decRef() e.mu.Lock() - switch e.state { + switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set // to indicate EPIPE. @@ -981,103 +984,57 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: e.mu.Unlock() + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + + // Notify protocol goroutine. This is required when + // handleSegment is invoked from the processor goroutine + // rather than the worker goroutine. + e.notifyProtocolGoroutine(notifyResetByPeer) return false, tcpip.ErrConnectionReset } } return true, nil } -// handleSegments pulls segments from the queue and processes them. It returns -// no error if the protocol loop should continue, an error otherwise. -func (e *endpoint) handleSegments() *tcpip.Error { +// handleSegments processes all inbound segments. +func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + if e.EndpointState() == StateClose || e.EndpointState() == StateError { + return nil + } s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false break } - // Invoke the tcp probe if installed. - if e.probe != nil { - e.probe(e.completeState()) + cont, err := e.handleSegment(s) + if err != nil { + s.decRef() + e.mu.Lock() + e.setEndpointState(StateError) + e.HardError = err + e.mu.Unlock() + return err } - - if s.flagIsSet(header.TCPFlagRst) { - if ok, err := e.handleReset(s); !ok { - return err - } - } else if s.flagIsSet(header.TCPFlagSyn) { - // See: https://tools.ietf.org/html/rfc5961#section-4.1 - // 1) If the SYN bit is set, irrespective of the sequence number, TCP - // MUST send an ACK (also referred to as challenge ACK) to the remote - // peer: - // - // - // - // After sending the acknowledgment, TCP MUST drop the unacceptable - // segment and stop processing further. - // - // By sending an ACK, the remote peer is challenged to confirm the loss - // of the previous connection and the request to start a new connection. - // A legitimate peer, after restart, would not have a TCB in the - // synchronized state. Thus, when the ACK arrives, the peer should send - // a RST segment back with the sequence number derived from the ACK - // field that caused the RST. - - // This RST will confirm that the remote peer has indeed closed the - // previous connection. Upon receipt of a valid RST, the local TCP - // endpoint MUST terminate its connection. The local TCP endpoint - // should then rely on SYN retransmission from the remote end to - // re-establish the connection. - - e.snd.sendAck() - } else if s.flagIsSet(header.TCPFlagAck) { - // Patch the window size in the segment according to the - // send window scale. - s.window <<= e.snd.sndWndScale - - // RFC 793, page 41 states that "once in the ESTABLISHED - // state all segments must carry current acknowledgment - // information." - drop, err := e.rcv.handleRcvdSegment(s) - if err != nil { - s.decRef() - return err - } - if drop { - s.decRef() - continue - } - - // Now check if the received segment has caused us to transition - // to a CLOSED state, if yes then terminate processing and do - // not invoke the sender. - e.mu.RLock() - state := e.state - e.mu.RUnlock() - if state == StateClose { - // When we get into StateClose while processing from the queue, - // return immediately and let the protocolMainloop handle it. - // - // We can reach StateClose only while processing a previous segment - // or a notification from the protocolMainLoop (caller goroutine). - // This means that with this return, the segment dequeue below can - // never occur on a closed endpoint. - s.decRef() - return nil - } - e.snd.handleRcvdSegment(s) + if !cont { + s.decRef() + return nil } - s.decRef() } - // If the queue is not empty, make sure we'll wake up in the next - // iteration. - if checkRequeue && !e.segmentQueue.empty() { + // When fastPath is true we don't want to wake up the worker + // goroutine. If the endpoint has more segments to process the + // dispatcher will call handleSegments again anyway. + if !fastPath && checkRequeue && !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } @@ -1086,11 +1043,88 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } - e.resetKeepaliveTimer(true) + e.resetKeepaliveTimer(true /* receivedData */) return nil } +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(header.TCPFlagRst) { + if ok, err := e.handleReset(s); !ok { + return false, err + } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + return false, err + } + if drop { + return true, nil + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return false, nil + } + + e.snd.handleRcvdSegment(s) + } + + return true, nil +} + // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. @@ -1160,7 +1194,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1182,6 +1216,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() + e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1193,7 +1228,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) e.mu.Lock() - h.ep.state = StateSynSent + h.ep.setEndpointState(StateSynSent) e.mu.Unlock() if err := h.execute(); err != nil { @@ -1202,12 +1237,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.state = StateError + e.setEndpointState(StateError) e.HardError = err // Lock released below. epilogue() - return err } } @@ -1215,7 +1249,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - // Tell waiters that the endpoint is connected and writable. e.mu.Lock() drained := e.drainDone != nil e.mu.Unlock() @@ -1224,8 +1257,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { <-e.undrain } - e.waiterQueue.Notify(waiter.EventOut) - // Set up the functions that will be called when the main protocol loop // wakes up. funcs := []struct { @@ -1240,18 +1271,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.sndCloseWaker, f: e.handleClose, }, - { - w: &e.newSegmentWaker, - f: e.handleSegments, - }, { w: &closeWaker, f: func() *tcpip.Error { // This means the socket is being closed due - // to the TCP_FIN_WAIT2 timeout was hit. Just + // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() e.transitionToStateCloseLocked() + e.workerCleanup = true e.mu.Unlock() return nil }, @@ -1266,6 +1294,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil }, }, + { + w: &e.newSegmentWaker, + f: func() *tcpip.Error { + return e.handleSegments(false /* fastPath */) + }, + }, { w: &e.keepalive.waker, f: e.keepaliveTimerExpired, @@ -1293,14 +1327,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } if n¬ifyReset != 0 { - e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.mu.Unlock() + return tcpip.ErrConnectionAborted + } + + if n¬ifyResetByPeer != 0 { + return tcpip.ErrConnectionReset } if n¬ifyClose != 0 && closeTimer == nil { e.mu.Lock() - if e.state == StateFinWait2 && e.closed { + if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { @@ -1320,11 +1356,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(); err != nil { + if err := e.handleSegments(false /* fastPath */); err != nil { return err } } - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1349,14 +1385,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.AddWaker(funcs[i].w, i) } + // Notify the caller that the waker initialization is complete and the + // endpoint is ready. + if wakerInitDone != nil { + close(wakerInitDone) + } + + // Tell waiters that the endpoint is connected and writable. + e.waiterQueue.Notify(waiter.EventOut) + // The following assertions and notifications are needed for restored // endpoints. Fresh newly created endpoints have empty states and should // not invoke any. - e.segmentQueue.mu.Lock() - if !e.segmentQueue.list.Empty() { + if !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } - e.segmentQueue.mu.Unlock() e.rcvListMu.Lock() if !e.rcvList.Empty() { @@ -1372,27 +1415,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() + // We need to double check here because the notification + // maybe stale by the time we got around to processing it. + // NOTE: since we now hold the workMu the processors cannot + // change the state of the endpoint so it' safe to proceed + // after this check. + if e.EndpointState() == StateTimeWait || e.EndpointState() == StateClose || e.EndpointState() == StateError { + e.mu.Lock() + break + } if err := funcs[v].f(); err != nil { e.mu.Lock() - // Ensure we release all endpoint registration and route - // references as the connection is now in an error - // state. e.workerCleanup = true e.resetConnectionLocked(err) // Lock released below. epilogue() - return nil } e.mu.Lock() } - state := e.state + state := e.EndpointState() e.mu.Unlock() var reuseTW func() if state == StateTimeWait { @@ -1405,13 +1453,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.mu.Lock() + e.workerCleanup = true + e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() - if e.state != StateError { - e.stack.Stats().TCP.CurrentEstablished.Decrement() + if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1468,7 +1518,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { - tcpEP.enqueueSegment(s) + if !tcpEP.enqueueSegment(s) { + s.decRef() + return + } + tcpEP.newSegmentWaker.Assert() } // We explicitly do not decRef // the segment as it's still diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go new file mode 100644 index 000000000..a72f0c379 --- /dev/null +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -0,0 +1,218 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// epQueue is a queue of endpoints. +type epQueue struct { + mu sync.Mutex + list endpointList +} + +// enqueue adds e to the queue if the endpoint is not already on the queue. +func (q *epQueue) enqueue(e *endpoint) { + q.mu.Lock() + if e.pendingProcessing { + q.mu.Unlock() + return + } + q.list.PushBack(e) + e.pendingProcessing = true + q.mu.Unlock() +} + +// dequeue removes and returns the first element from the queue if available, +// returns nil otherwise. +func (q *epQueue) dequeue() *endpoint { + q.mu.Lock() + if e := q.list.Front(); e != nil { + q.list.Remove(e) + e.pendingProcessing = false + q.mu.Unlock() + return e + } + q.mu.Unlock() + return nil +} + +// empty returns true if the queue is empty, false otherwise. +func (q *epQueue) empty() bool { + q.mu.Lock() + v := q.list.Empty() + q.mu.Unlock() + return v +} + +// processor is responsible for processing packets queued to a tcp endpoint. +type processor struct { + epQ epQueue + newEndpointWaker sleep.Waker + id int +} + +func newProcessor(id int) *processor { + p := &processor{ + id: id, + } + go p.handleSegments() + return p +} + +func (p *processor) queueEndpoint(ep *endpoint) { + // Queue an endpoint for processing by the processor goroutine. + p.epQ.enqueue(ep) + p.newEndpointWaker.Assert() +} + +func (p *processor) handleSegments() { + const newEndpointWaker = 1 + s := sleep.Sleeper{} + s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + defer s.Done() + for { + s.Fetch(true) + for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + if ep.segmentQueue.empty() { + continue + } + + // If socket has transitioned out of connected state + // then just let the worker handle the packet. + // + // NOTE: We read this outside of e.mu lock which means + // that by the time we get to handleSegments the + // endpoint may not be in ESTABLISHED. But this should + // be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a + // CLOSED/ERROR state then handleSegments is a noop. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + continue + } + + if !ep.workMu.TryLock() { + ep.newSegmentWaker.Assert() + continue + } + // If the endpoint is in a connected state then we do + // direct delivery to ensure low latency and avoid + // scheduler interactions. + if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { + ep.notifyProtocolGoroutine(notifyTickleWorker) + ep.workMu.Unlock() + continue + } + + if !ep.segmentQueue.empty() { + p.epQ.enqueue(ep) + } + + ep.workMu.Unlock() + } + } +} + +// dispatcher manages a pool of TCP endpoint processors which are responsible +// for the processing of inbound segments. This fixed pool of processor +// goroutines do full tcp processing. The processor is selected based on the +// hash of the endpoint id to ensure that delivery for the same endpoint happens +// in-order. +type dispatcher struct { + processors []*processor + seed uint32 +} + +func newDispatcher(nProcessors int) *dispatcher { + processors := []*processor{} + for i := 0; i < nProcessors; i++ { + processors = append(processors, newProcessor(i)) + } + return &dispatcher{ + processors: processors, + seed: generateRandUint32(), + } +} + +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + ep := stackEP.(*endpoint) + s := newSegment(r, id, pkt) + if !s.parse() { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + s.decRef() + return + } + + if !s.csumValid { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.ChecksumErrors.Increment() + ep.stats.ReceiveErrors.ChecksumErrors.Increment() + s.decRef() + return + } + + ep.stack.Stats().TCP.ValidSegmentsReceived.Increment() + ep.stats.SegmentsReceived.Increment() + if (s.flags & header.TCPFlagRst) != 0 { + ep.stack.Stats().TCP.ResetsReceived.Increment() + } + + if !ep.enqueueSegment(s) { + s.decRef() + return + } + + // For sockets not in established state let the worker goroutine + // handle the packets. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + return + } + + d.selectProcessor(id).queueEndpoint(ep) +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { + payload := []byte{ + byte(id.LocalPort), + byte(id.LocalPort >> 8), + byte(id.RemotePort), + byte(id.RemotePort >> 8)} + + h := jenkins.Sum32(d.seed) + h.Write(payload) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + + return d.processors[h.Sum32()%uint32(len(d.processors))] +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cc8b533c8..1799c6e10 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -120,6 +120,7 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyResetByPeer notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -127,6 +128,7 @@ const ( // ensures the loop terminates if the final state of the endpoint is // say TIME_WAIT. notifyTickleWorker + notifyError ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {} type endpoint struct { EndpointInfo + // endpointEntry is used to queue endpoints for processing to the + // a given tcp processor goroutine. + // + // Precondition: epQueue.mu must be held to read/write this field.. + endpointEntry `state:"nosave"` + + // pendingProcessing is true if this endpoint is queued for processing + // to a TCP processor. + // + // Precondition: epQueue.mu must be held to read/write this field.. + pendingProcessing bool `state:"nosave"` + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -324,6 +338,7 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` + // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -359,7 +374,7 @@ type endpoint struct { workerRunning bool // workerCleanup specifies if the worker goroutine must perform cleanup - // before exitting. This can only be set to true when workerRunning is + // before exiting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. workerCleanup bool @@ -371,6 +386,8 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. + // + // recentTS must be read/written atomically. recentTS uint32 // tsOffset is a randomized offset added to the value of the @@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() { e.workMu.Unlock() } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + switch state { + case StateEstablished: + e.stack.Stats().TCP.CurrentEstablished.Increment() + case StateError: + fallthrough + case StateClose: + if oldstate == StateCloseWait || oldstate == StateEstablished { + e.stack.Stats().TCP.EstablishedResets.Increment() + } + fallthrough + default: + if oldstate == StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } + } + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp atomically sets the recentTS field to the +// provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + atomic.StoreUint32(&e.recentTS, recentTS) +} + +// recentTimestamp atomically reads and returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return atomic.LoadUint32(&e.recentTS) +} + // keepalive is a synchronization wrapper used to appease stateify. See the // comment in endpoint, where it is used. // @@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.mu.RLock() defer e.mu.RUnlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } } - if e.state.connected() { + if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -733,14 +791,20 @@ func (e *endpoint) Close() { // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdown() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdown() { e.mu.Lock() // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == StateListen && e.isPortReserved { + if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false @@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { defer close(done) for n := range e.acceptedChan { n.notifyProtocolGoroutine(notifyReset) + // close all connections that have completed but + // not accepted by the application. n.Close() } }() @@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { e.closePendingAcceptableConnectionsLocked() } + e.workerCleanup = false if e.isRegistered { @@ -920,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError e.mu.RUnlock() @@ -944,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -980,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. - if !e.state.connected() { - switch e.state { + if !e.EndpointState().connected() { + switch e.EndpointState() { case StateError: return 0, e.HardError default: @@ -1039,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, perr } - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + if opts.Atomic { + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { + if e.workMu.TryLock() { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] } - } - - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - if e.workMu.TryLock() { // Do the work inline. e.handleWrite() e.workMu.Unlock() } else { + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + + } // Let the protocol goroutine do the work. e.sndWaker.Assert() } @@ -1091,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; !s.connected() && s != StateClose { + if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { return 0, tcpip.ControlMessages{}, e.HardError } @@ -1103,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } @@ -1187,7 +1299,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1402,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Acquire the work mutex as we may need to // reinitialize the congestion control state. e.mu.Lock() - state := e.state + state := e.EndpointState() e.cc = v e.mu.Unlock() switch state { case StateEstablished: e.workMu.Lock() e.mu.Lock() - if e.state == state { + if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } e.mu.Unlock() @@ -1472,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == StateListen { + if e.EndpointState() == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1731,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return err } - if e.state.connected() { + if e.EndpointState().connected() { // The endpoint is already connected. If caller hasn't been // notified yet, return success. if !e.isConnectNotified { @@ -1743,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } nicID := addr.NIC - switch e.state { + switch e.EndpointState() { case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. @@ -1850,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.isRegistered = true - e.state = StateConnecting + e.setEndpointState(StateConnecting) e.route = r.Clone() e.boundNICID = nicID e.effectiveNetProtos = netProtos @@ -1871,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = StateEstablished - e.stack.Stats().TCP.CurrentEstablished.Increment() + e.setEndpointState(StateEstablished) } if run { e.workerRunning = true e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } return tcpip.ErrConnectStarted @@ -1896,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags finQueued := false switch { - case e.state.connected(): + case e.EndpointState().connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1908,8 +2019,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) + // Move the socket to error state immediately. + // This is done redundantly because in case of + // save/restore on a Shutdown/Close() the socket + // state needs to indicate the error otherwise + // save file will show the socket in established + // state even though snd/rcv are closed. e.mu.Unlock() + // Try to send an active reset immediately if the + // work mutex is available. + if e.workMu.TryLock() { + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + e.workMu.Unlock() + } else { + e.notifyProtocolGoroutine(notifyReset) + } return nil } } @@ -1931,11 +2057,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { finQueued = true // Mark endpoint as closed. e.sndClosed = true - e.sndBufMu.Unlock() } - case e.state == StateListen: + case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1976,7 +2101,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == StateListen && !e.workerCleanup { + if e.EndpointState() == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1994,7 +2119,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { // The listen is called on an unbound socket, the socket is // automatically bound to a random free port with the local // address set to INADDR_ANY. @@ -2004,7 +2129,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Endpoint must be bound before it can transition to listen mode. - if e.state != StateBound { + if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } @@ -2015,24 +2140,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } e.isRegistered = true - e.state = StateListen + e.setEndpointState(StateListen) + if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } e.workerRunning = true - go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) - return nil } // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.mu.Lock() e.waiterQueue = waiterQueue e.workerRunning = true - go e.protocolMainLoop(false) // S/R-SAFE: drained on save. + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone } // Accept returns a new endpoint if a peer has established a connection @@ -2042,7 +2170,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != StateListen { + if e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -2069,7 +2197,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrAlreadyBound } @@ -2131,7 +2259,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) return nil } @@ -2153,7 +2281,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if !e.state.connected() { + if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -2164,45 +2292,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { - s := newSegment(r, id, pkt) - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - s.decRef() - return - } - - if !s.csumValid { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() - e.stats.SegmentsReceived.Increment() - if (s.flags & header.TCPFlagRst) != 0 { - e.stack.Stats().TCP.ResetsReceived.Increment() - } - - e.enqueueSegment(s) + // TCP HandlePacket is not required anymore as inbound packets first + // land at the Dispatcher which then can either delivery using the + // worker go routine or directly do the invoke the tcp processing inline + // based on the state of the endpoint. } -func (e *endpoint) enqueueSegment(s *segment) { +func (e *endpoint) enqueueSegment(s *segment) bool { // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - e.newSegmentWaker.Assert() - } else { + if !e.segmentQueue.enqueue(s) { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.SegmentQueueDropped.Increment() - s.decRef() + return false } + return true } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. @@ -2319,8 +2424,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { - e.recentTS = tsVal + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) } } @@ -2330,7 +2435,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { e.sendTSOk = true - e.recentTS = synOpts.TSVal + e.setRecentTimestamp(synOpts.TSVal) } } @@ -2419,7 +2524,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Endpoint TCP Option state. s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTS + s.RecentTS = e.recentTimestamp() s.TSOffset = e.tsOffset s.SACKPermitted = e.sackPermitted s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) @@ -2526,9 +2631,7 @@ func (e *endpoint) initGSO() { // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4b8d867bc..4a46f0ec5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,6 +16,7 @@ package tcp import ( "fmt" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sync" @@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound: // TODO(b/138137272): this enumeration duplicates // EndpointState.connected. remove it. @@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() { fallthrough case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } - fallthrough case StateError, StateClose: - for (e.state == StateError || e.state == StateClose) && e.workerRunning { + for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() } if e.workerRunning { - panic("endpoint still has worker running in closed or error state") + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) } default: - panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) } if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { + if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { // saveState is invoked by stateify. func (e *endpoint) saveState() EndpointState { - return e.state + return e.EndpointState() } // Endpoint loading must be done in the following ordering by their state, to @@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. - if state.connected() { + // For restore purposes we treat TimeWait like a connected endpoint. + if state.connected() || state == StateTimeWait { connectedLoading.Add(1) } switch state { @@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) { case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } - e.state = state + // Directly update the state here rather than using e.setEndpointState + // as the endpoint is still being loaded and the stack reference to increment + // metrics is not yet initialized. + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - // Freeze segment queue before registering to prevent any segments - // from being delivered while it is being restored. e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. @@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) { e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() state := e.origEndpointState - switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = StateClose + e.setEndpointState(StateClose) tcpip.AsyncLoading.Done() }() } @@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } + } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 9a8f64aa6..958c06fa7 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "runtime" "strings" "time" @@ -104,6 +105,7 @@ type protocol struct { moderateReceiveBuffer bool tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration + dispatcher *dispatcher } // Number returns the tcp protocol number. @@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } +// QueuePacket queues packets targeted at an endpoint after hashing the packet +// to a specific processing queue. Each queue is serviced by its own processor +// goroutine which is responsible for dequeuing and doing full TCP dispatch of +// the packet. +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + p.dispatcher.queuePacket(r, ep, id, pkt) +} + // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. // @@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 05c8488f8..958f03ac1 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -169,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateEstablished: - r.ep.state = StateCloseWait + r.ep.setEndpointState(StateCloseWait) case StateFinWait1: if s.flagIsSet(header.TCPFlagAck) { // FIN-ACK, transition to TIME-WAIT. - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } else { // Simultaneous close, expecting a final ACK. - r.ep.state = StateClosing + r.ep.setEndpointState(StateClosing) } case StateFinWait2: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } r.ep.mu.Unlock() @@ -205,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateFinWait1: - r.ep.state = StateFinWait2 + r.ep.setEndpointState(StateFinWait2) // Notify protocol goroutine that we have received an // ACK to our FIN so that it can start the FIN_WAIT2 // timer to abort connection if the other side does // not close within 2MSL. r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) case StateLastAck: r.ep.transitionToStateCloseLocked() } @@ -267,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo switch state { case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { - s.decRef() // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -284,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - s.decRef() return true, tcpip.ErrConnectionAborted } if state == StateFinWait1 { @@ -314,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - s.decRef() return true, tcpip.ErrConnectionAborted } } @@ -336,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { r.ep.mu.RLock() - state := r.ep.state + state := r.ep.EndpointState() closed := r.ep.closed r.ep.mu.RUnlock() diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index fdff7ed81..b74b61e7d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -705,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) - // Transition to FIN-WAIT1 state since we're initiating an active close. - s.ep.mu.Lock() - switch s.ep.state { + // Update the state to reflect that we have now + // queued a FIN. + switch s.ep.EndpointState() { case StateCloseWait: - // We've already received a FIN and are now sending our own. The - // sender is now awaiting a final ACK for this FIN. - s.ep.state = StateLastAck + s.ep.setEndpointState(StateLastAck) default: - s.ep.state = StateFinWait1 + s.ep.setEndpointState(StateFinWait1) } - s.ep.mu.Unlock() + } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6edfa8dce..a9dfbe857 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SeqNum(uint32(c.IRS+1)), checker.AckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence number + // of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(0), checker.TCPFlags(header.TCPFlagRst), @@ -1500,6 +1502,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence + // number of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. @@ -5441,6 +5446,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { rawEP.SendPacketWithTS(b[start:start+mss], tsVal) packetsSent++ } + // Resume the worker so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -5508,7 +5514,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { stk := c.Stack() // Set lower limits for auto-tuning tests. This is required because the // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. + // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { @@ -5563,6 +5569,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalSent += mss packetsSent++ } + // Resume it so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 5d114d460..2f9821555 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -533,7 +533,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // Sleep for a little over the linger timeout to reduce flakiness in // save/restore tests. - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); ds.reset(); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 6b99c021d..33a5ac66c 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -814,6 +814,20 @@ TEST_P(TcpSocketTest, FullBuffer) { t_ = -1; } +TEST_P(TcpSocketTest, PollAfterShutdown) { + ScopedThread client_thread([this]() { + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); + }); + + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); +} + TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { // Initialize address to the loopback one. sockaddr_storage addr = -- cgit v1.2.3 From 92a00ca91affab8564b8875387758914ddc9785b Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Wed, 18 Mar 2020 16:25:20 -0700 Subject: Store segment transmit count. This will aid in segment reordering detection. Updates #691 PiperOrigin-RevId: 301692638 --- pkg/tcpip/transport/tcp/segment.go | 6 +++--- pkg/tcpip/transport/tcp/snd.go | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 1c10da5ca..5d0bc4f72 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -56,9 +56,9 @@ type segment struct { options []byte `state:".([]byte)"` hasNewSACKInfo bool rcvdTime time.Time `state:".(unixTime)"` - // xmitTime is the last transmit time of this segment. A zero value - // indicates that the segment has yet to be transmitted. - xmitTime time.Time `state:".(unixTime)"` + // xmitTime is the last transmit time of this segment. + xmitTime time.Time `state:".(unixTime)"` + xmitCount uint32 } func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index b74b61e7d..657c3146e 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1229,7 +1229,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { // sendSegment sends the specified segment. func (s *sender) sendSegment(seg *segment) *tcpip.Error { - if !seg.xmitTime.IsZero() { + if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { @@ -1237,6 +1237,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { } } seg.xmitTime = time.Now() + seg.xmitCount++ return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) } -- cgit v1.2.3 From e9e399c25d4fcad2adfe92d73b192b9784774964 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 19 Mar 2020 07:18:47 -0700 Subject: Remove workMu from tcpip.Endpoint. workMu is removed and e.mu is now a mutex that supports TryLock. The packet processing path tries to lock the mutex and if its locked it will just queue the packet and move on. The endpoint.UnlockUser() will process any backlog of packets before unlocking the socket. This simplifies the locking inside tcp endpoints a lot. Further the endpoint.LockUser() implements spinning as long as the lock is not held by another syscall goroutine. This ensures low latency as not spinning leads to the task thread being put to sleep if the lock is held by the packet dispatch path. This is suboptimal as the lower layer rarely holds the lock for long so implementing spinning here helps. If the lock is held by another task goroutine then we just proceed to call LockUser() and the task could be put to sleep. The protocol goroutines themselves just call e.mu.Lock() and block if the lock is currently not available. Updates #231, #357 PiperOrigin-RevId: 301808349 --- pkg/sentry/kernel/epoll/epoll.go | 2 + pkg/sentry/socket/netstack/netstack.go | 24 +- pkg/tcpip/transport/tcp/accept.go | 90 +++--- pkg/tcpip/transport/tcp/connect.go | 78 ++--- pkg/tcpip/transport/tcp/dispatcher.go | 8 +- pkg/tcpip/transport/tcp/endpoint.go | 495 ++++++++++++++++-------------- pkg/tcpip/transport/tcp/endpoint_state.go | 5 +- pkg/tcpip/transport/tcp/protocol.go | 38 +-- pkg/tcpip/transport/tcp/rcv.go | 6 - pkg/tcpip/transport/tcp/segment_queue.go | 8 +- pkg/tcpip/transport/tcp/snd.go | 3 - pkg/tcpip/transport/tcp/tcp_test.go | 41 ++- 12 files changed, 424 insertions(+), 374 deletions(-) (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 8bffb78fc..592650923 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -296,8 +296,10 @@ func (*readyCallback) Callback(w *waiter.Entry) { e.waitingList.Remove(entry) e.readyList.PushBack(entry) entry.curList = &e.readyList + e.listsMu.Unlock() e.Notify(waiter.EventIn) + return } e.listsMu.Unlock() diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 13a9a60b4..a2e1da02f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,6 +29,7 @@ import ( "io" "math" "reflect" + "sync/atomic" "syscall" "time" @@ -264,6 +265,12 @@ type SocketOperations struct { skType linux.SockType protocol int + // readViewHasData is 1 iff readView has data to be read, 0 otherwise. + // Must be accessed using atomic operations. It must only be written + // with readMu held but can be read without holding readMu. The latter + // is required to avoid deadlocks in epoll Readiness checks. + readViewHasData uint32 + // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` // readView contains the remaining payload from the last packet. @@ -410,21 +417,24 @@ func (s *SocketOperations) isPacketBased() bool { // fetchReadView updates the readView field of the socket if it's currently // empty. It assumes that the socket is locked. +// +// Precondition: s.readMu must be held. func (s *SocketOperations) fetchReadView() *syserr.Error { if len(s.readView) > 0 { return nil } - s.readView = nil s.sender = tcpip.FullAddress{} v, cms, err := s.Endpoint.Read(&s.sender) if err != nil { + atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms + atomic.StoreUint32(&s.readViewHasData, 1) return nil } @@ -623,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - s.readMu.Lock() - if len(s.readView) > 0 { + if atomic.LoadUint32(&s.readViewHasData) == 1 { r |= waiter.EventIn } - s.readMu.Unlock() } return r @@ -2334,6 +2342,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) @@ -2456,6 +2468,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 85049e54e..4d7602d54 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -221,7 +221,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu } // createConnectingEndpoint creates a new endpoint in a connecting state, with -// the connection parameters given by the arguments. +// the connection parameters given by the arguments. The endpoint is returned +// with n.mu held. func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto @@ -243,21 +244,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() - // Now inherit any socket options that should be inherited from the - // listening endpoint. - // In case of Forwarder listenEP will be nil and hence this check. - if l.listenEP != nil { - l.listenEP.propagateInheritableOptions(n) - } - - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { - n.Close() - return nil, err - } - - n.isRegistered = true - // Create sender and receiver. // // The receiver at least temporarily has a zero receive window scale, @@ -269,11 +255,27 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + n.mu.Lock() + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + return nil, err + } + + n.isRegistered = true + return n, nil } // createEndpointAndPerformHandshake creates a new endpoint in connected state // and then performs the TCP 3-way handshake. +// +// The new endpoint is returned with e.mu held. func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber @@ -289,9 +291,25 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() + // Ensure we release any registrations done by the newly + // created endpoint. + ep.mu.Unlock() + ep.Close() + + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + l.listenEP.propagateInheritableOptionsLocked(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } @@ -299,6 +317,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { + ep.mu.Unlock() ep.Close() // Wake up any waiters. This is strictly not required normally // as a socket that was never accepted can't really have any @@ -312,9 +331,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head } return nil, err } - ep.mu.Lock() ep.isConnectNotified = true - ep.mu.Unlock() // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -366,12 +383,12 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } -// propagateInheritableOptions propagates any options set on the listening +// propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. -func (e *endpoint) propagateInheritableOptions(n *endpoint) { - e.mu.Lock() +// +// Precondition: e.mu and n.mu must be held. +func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout - e.mu.Unlock() } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -382,7 +399,11 @@ func (e *endpoint) propagateInheritableOptions(n *endpoint) { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer decSynRcvdCount() - defer e.decSynRcvdCount() + defer func() { + e.mu.Lock() + e.decSynRcvdCount() + e.mu.Unlock() + }() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) @@ -399,29 +420,21 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } func (e *endpoint) incSynRcvdCount() bool { - e.mu.Lock() - if e.synRcvdCount >= cap(e.acceptedChan) { - e.mu.Unlock() + if e.synRcvdCount >= (cap(e.acceptedChan)) { return false } e.synRcvdCount++ - e.mu.Unlock() return true } func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() e.synRcvdCount-- - e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { - e.mu.Lock() if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { - e.mu.Unlock() return true } - e.mu.Unlock() return false } @@ -559,6 +572,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + // clear the tsOffset for the newly created // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was @@ -593,14 +610,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only - e.mu.Unlock() ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections // to the endpoint. - e.mu.Lock() e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. @@ -622,7 +637,10 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { - switch index, _ := s.Fetch(true); index { + e.mu.Unlock() + index, _ := s.Fetch(true) + e.mu.Lock() + switch index { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { @@ -635,7 +653,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { s.decRef() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } case wakerForNewSegment: diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index be86af502..edb37a549 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -61,6 +61,9 @@ const ( ) // handshake holds the state used during a TCP 3-way handshake. +// +// NOTE: handshake.ep.mu is held during handshake processing. It is released if +// we are going to block and reacquired when we start processing an event. type handshake struct { ep *endpoint state handshakeState @@ -209,9 +212,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.deferAccept = deferAccept - h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) - h.ep.mu.Unlock() } // checkAck checks if the ACK number, if present, of a segment received during @@ -241,9 +242,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK // was acceptable then signal the user "error: connection reset", drop // the segment, enter CLOSED state, delete TCB, and return." - h.ep.mu.Lock() h.ep.workerCleanup = true - h.ep.mu.Unlock() // Although the RFC above calls out ECONNRESET, Linux actually returns // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused @@ -281,9 +280,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted - h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) - h.ep.mu.Unlock() h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil @@ -293,11 +290,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd - h.ep.mu.Lock() ttl := h.ep.ttl amss := h.ep.amss h.ep.setEndpointState(StateSynRecv) - h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, @@ -357,10 +352,6 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - h.ep.mu.RLock() - amss := h.ep.amss - h.ep.mu.RUnlock() - h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -368,7 +359,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, - MSS: amss, + MSS: h.ep.amss, } h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -399,15 +390,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { } h.state = handshakeCompleted - h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver // to process it again once main loop is started. if s.data.Size() > 0 { s.incRef() h.ep.enqueueSegment(s) } - h.ep.mu.Unlock() return nil } @@ -493,7 +483,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } if n¬ifyDrain != 0 { close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() } } @@ -535,7 +527,6 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.mu.Lock() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ @@ -546,7 +537,6 @@ func (h *handshake) execute() *tcpip.Error { SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } - h.ep.mu.Unlock() // Execute is also called in a listen context so we want to make sure we // only send the TS/SACK option when we received the TS/SACK in the @@ -563,7 +553,11 @@ func (h *handshake) execute() *tcpip.Error { h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) for h.state != handshakeCompleted { - switch index, _ := s.Fetch(true); index { + h.ep.mu.Unlock() + index, _ := s.Fetch(true) + h.ep.mu.Lock() + switch index { + case wakerForResend: timeOut *= 2 if timeOut > MaxRTO { @@ -600,7 +594,9 @@ func (h *handshake) execute() *tcpip.Error { } } close(h.ep.drainDone) + h.ep.mu.Unlock() <-h.ep.undrain + h.ep.mu.Lock() } case wakerForNewSegment: @@ -1016,7 +1012,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - e.mu.Lock() switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set @@ -1040,11 +1035,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { case StateCloseWait: e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted - e.mu.Unlock() e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: - e.mu.Unlock() // RFC 793, page 37 states that "in all states // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So @@ -1157,9 +1150,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - e.mu.RLock() state := e.state - e.mu.RUnlock() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1182,9 +1173,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { - e.mu.RLock() userTimeout := e.userTimeout - e.mu.RUnlock() e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { @@ -1248,6 +1237,7 @@ func (e *endpoint) disableKeepaliveTimer() { // goroutine and is responsible for sending segments and handling received // segments. func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { + e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1269,7 +1259,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.mu.Unlock() - e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1280,16 +1269,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // completion. initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) - e.mu.Lock() h.ep.setEndpointState(StateSynSent) - e.mu.Unlock() if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() - e.mu.Lock() e.setEndpointState(StateError) e.HardError = err @@ -1302,9 +1288,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - e.mu.Lock() drained := e.drainDone != nil - e.mu.Unlock() if drained { close(e.drainDone) <-e.undrain @@ -1330,10 +1314,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // This means the socket is being closed due // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. - e.mu.Lock() e.transitionToStateCloseLocked() e.workerCleanup = true - e.mu.Unlock() return nil }, }, @@ -1388,7 +1370,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n¬ifyClose != 0 && closeTimer == nil { - e.mu.Lock() if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. @@ -1397,7 +1378,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }) e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } - e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1417,7 +1397,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } } @@ -1460,7 +1442,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } e.rcvListMu.Unlock() - e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } @@ -1468,7 +1449,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. cleanupOnError := func(err *tcpip.Error) { - e.mu.Lock() e.workerCleanup = true if err != nil { e.resetConnectionLocked(err) @@ -1480,16 +1460,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ loop: for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() - e.workMu.Unlock() v, _ := s.Fetch(true) - e.workMu.Lock() + e.mu.Lock() - // We need to double check here because the notification maybe + // We need to double check here because the notification may be // stale by the time we got around to processing it. - // - // NOTE: since we now hold the workMu the processors cannot - // change the state of the endpoint so it's safe to proceed - // after this check. switch e.EndpointState() { case StateError: // If the endpoint has already transitioned to an ERROR @@ -1502,21 +1477,17 @@ loop: case StateTimeWait: fallthrough case StateClose: - e.mu.Lock() break loop default: if err := funcs[v].f(); err != nil { cleanupOnError(err) return nil } - e.mu.Lock() } } - state := e.EndpointState() - e.mu.Unlock() var reuseTW func() - if state == StateTimeWait { + if e.EndpointState() == StateTimeWait { // Disable close timer as we now entering real TIME_WAIT. if closeTimer != nil { closeTimer.Stop() @@ -1526,14 +1497,11 @@ loop: s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.mu.Lock() e.workerCleanup = true - e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. - e.mu.Lock() if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1649,9 +1617,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) { defer timeWaitTimer.Stop() for { - e.workMu.Unlock() + e.mu.Unlock() v, _ := s.Fetch(true) - e.workMu.Lock() + e.mu.Lock() switch v { case newSegment: extendTimeWait, reuseTW := e.handleTimeWaitSegments() @@ -1674,7 +1642,9 @@ func (e *endpoint) doTimeWait() (twReuse func()) { e.handleTimeWaitSegments() } close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() return nil } case timeWaitDone: diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index d792b07d6..90ac956a9 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -128,7 +128,7 @@ func (p *processor) handleSegments() { continue } - if !ep.workMu.TryLock() { + if !ep.mu.TryLock() { ep.newSegmentWaker.Assert() continue } @@ -138,12 +138,10 @@ func (p *processor) handleSegments() { if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { // Send any active resets if required. if err != nil { - ep.mu.Lock() ep.resetConnectionLocked(err) - ep.mu.Unlock() } ep.notifyProtocolGoroutine(notifyTickleWorker) - ep.workMu.Unlock() + ep.mu.Unlock() continue } @@ -151,7 +149,7 @@ func (p *processor) handleSegments() { p.epQ.enqueue(ep) } - ep.workMu.Unlock() + ep.mu.Unlock() } } } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 5187a5e25..eb8a9d73e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -18,6 +18,7 @@ import ( "encoding/binary" "fmt" "math" + "runtime" "strings" "sync/atomic" "time" @@ -33,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tmutex" "gvisor.dev/gvisor/pkg/waiter" ) @@ -283,6 +283,37 @@ func (*EndpointInfo) IsEndpointInfo() {} // synchronized. The protocol implementation, however, runs in a single // goroutine. // +// Each endpoint has a few mutexes: +// +// e.mu -> Primary mutex for an endpoint must be held for all operations except +// in e.Readiness where acquiring it will result in a deadlock in epoll +// implementation. +// +// The following three mutexes can be acquired independent of e.mu but if +// acquired with e.mu then e.mu must be acquired first. +// +// e.rcvListMu -> Protects the rcvList and associated fields. +// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.lastErrorMu -> Protects the lastError field. +// +// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different +// based on the context in which the lock is acquired. In the syscall context +// e.LockUser/e.UnlockUser should be used and when doing background processing +// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below +// in brief. +// +// The reason for this locking behaviour is to avoid wakeups to handle packets. +// In cases where the endpoint is already locked the background processor can +// queue the packet up and go its merry way and the lock owner will eventually +// process the backlog when releasing the lock. Similarly when acquiring the +// lock from say a syscall goroutine we can implement a bit of spinning if we +// know that the lock is not held by another syscall goroutine. Background +// processors should never hold the lock for long and we can avoid an expensive +// sleep/wakeup by spinning for a shortwhile. +// +// For more details please see the detailed documentation on +// e.LockUser/e.UnlockUser methods. +// // +stateify savable type endpoint struct { EndpointInfo @@ -299,12 +330,6 @@ type endpoint struct { // Precondition: epQueue.mu must be held to read/write this field.. pendingProcessing bool `state:"nosave"` - // workMu is used to arbitrate which goroutine may perform protocol - // work. Only the main protocol goroutine is expected to call Lock() on - // it, but other goroutines (e.g., send) may call TryLock() to eagerly - // perform work without having to wait for the main one to wake up. - workMu tmutex.Mutex `state:"nosave"` - // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` @@ -330,15 +355,11 @@ type endpoint struct { rcvBufSize int rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams - // zeroWindow indicates that the window was closed due to receive buffer - // space being filled up. This is set by the worker goroutine before - // moving a segment to the rcvList. This setting is cleared by the - // endpoint when a Read() call reads enough data for the new window to - // be non-zero. - zeroWindow bool - // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` + // mu protects all endpoint fields unless documented otherwise. mu must + // be acquired before interacting with the endpoint fields. + mu sync.Mutex `state:"nosave"` + ownedByUser uint32 // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` @@ -583,14 +604,93 @@ func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { return maxMSS } +// LockUser tries to lock e.mu and if it fails it will check if the lock is held +// by another syscall goroutine. If yes, then it will goto sleep waiting for the +// lock to be released, if not then it will spin till it acquires the lock or +// another syscall goroutine acquires it in which case it will goto sleep as +// described above. +// +// The assumption behind spinning here being that background packet processing +// should not be holding the lock for long and spinning reduces latency as we +// avoid an expensive sleep/wakeup of of the syscall goroutine). +func (e *endpoint) LockUser() { + for { + // Try first if the sock is locked then check if it's owned + // by another user goroutine if not then we spin, otherwise + // we just goto sleep on the Lock() and wait. + if !e.mu.TryLock() { + // If socket is owned by the user then just goto sleep + // as the lock could be held for a reasonably long time. + if atomic.LoadUint32(&e.ownedByUser) == 1 { + e.mu.Lock() + atomic.StoreUint32(&e.ownedByUser, 1) + return + } + // Spin but yield the processor since the lower half + // should yield the lock soon. + runtime.Gosched() + continue + } + atomic.StoreUint32(&e.ownedByUser, 1) + return + } +} + +// UnlockUser will check if there are any segments already queued for processing +// and process any such segments before unlocking e.mu. This is required because +// we when packets arrive and endpoint lock is already held then such packets +// are queued up to be processed. If the lock is held by the endpoint goroutine +// then it will process these packets but if the lock is instead held by the +// syscall goroutine then we can have the syscall goroutine process the backlog +// before unlocking. +// +// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the +// endpoint. It's also required eventually when we get rid of the endpoint +// protocol goroutine altogether. +// +// Precondition: e.LockUser() must have been called before calling e.UnlockUser() +func (e *endpoint) UnlockUser() { + // Lock segment queue before checking so that we avoid a race where + // segments can be queued between the time we check if queue is empty + // and actually unlock the endpoint mutex. + for { + e.segmentQueue.mu.Lock() + if e.segmentQueue.emptyLocked() { + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + e.segmentQueue.mu.Unlock() + return + } + e.segmentQueue.mu.Unlock() + + switch e.EndpointState() { + case StateEstablished: + if err := e.handleSegments(true /* fastPath */); err != nil { + e.notifyProtocolGoroutine(notifyTickleWorker) + } + default: + // Since we are waking the endpoint goroutine here just unlock + // and let it process the queued segments. + e.newSegmentWaker.Assert() + if atomic.SwapUint32(&e.ownedByUser, 0) != 1 { + panic("e.UnlockUser() called without calling e.LockUser()") + } + e.mu.Unlock() + return + } + } +} + // StopWork halts packet processing. Only to be used in tests. func (e *endpoint) StopWork() { - e.workMu.Lock() + e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. func (e *endpoint) ResumeWork() { - e.workMu.Unlock() + e.mu.Unlock() } // setEndpointState updates the state of the endpoint to state atomically. This @@ -709,8 +809,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - e.workMu.Lock() e.tsOffset = timeStampOffset() return e @@ -721,9 +819,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) - e.mu.RLock() - defer e.mu.RUnlock() - switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -823,20 +918,22 @@ func (e *endpoint) Abort() { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { - e.mu.Lock() - closed := e.closed - e.closed = true - e.mu.Unlock() - if closed { + e.LockUser() + defer e.UnlockUser() + if e.closed { return } // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. - e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) - - e.mu.Lock() + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdownLocked() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdownLocked() { // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail @@ -853,6 +950,8 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. switch e.EndpointState() { @@ -873,8 +972,6 @@ func (e *endpoint) Close() { // goroutine terminates. e.notifyProtocolGoroutine(notifyClose) } - - e.mu.Unlock() } // closePendingAcceptableConnections closes all connections that have completed @@ -909,7 +1006,6 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { - // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { @@ -954,18 +1050,18 @@ func (e *endpoint) initialReceiveWindow() int { // ModerateRecvBuf adjusts the receive buffer and the advertised window // based on the number of bytes copied to user space. func (e *endpoint) ModerateRecvBuf(copied int) { - e.mu.RLock() + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() - e.mu.RUnlock() return } now := time.Now() if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { e.rcvAutoParams.copied += copied e.rcvListMu.Unlock() - e.mu.RUnlock() return } prevRTTCopied := e.rcvAutoParams.copied + copied @@ -1021,7 +1117,6 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvAutoParams.measureTime = now e.rcvAutoParams.copied = 0 e.rcvListMu.Unlock() - e.mu.RUnlock() } // IPTables implements tcpip.Endpoint.IPTables. @@ -1031,7 +1126,7 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() + e.LockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the @@ -1041,7 +1136,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.mu.RUnlock() + e.UnlockUser() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } @@ -1051,7 +1146,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() @@ -1124,13 +1219,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. - e.mu.RLock() + e.LockUser() e.sndBufMu.Lock() avail, err := e.isEndpointWritableLocked() if err != nil { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -1142,113 +1237,68 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // are copying data in. if !opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } // Fetch data. v, perr := p.Payload(avail) if perr != nil || len(v) == 0 { - if opts.Atomic { // See above. + // Note that perr may be nil if len(v) == 0. + if opts.Atomic { e.sndBufMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() } - // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - if opts.Atomic { + queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. s := newSegmentFromView(&e.route, e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - } - if e.workMu.TryLock() { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() - - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - } // Do the work inline. e.handleWrite() - e.workMu.Unlock() - } else { - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + e.UnlockUser() + return int64(len(v)), nil, nil + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } + if opts.Atomic { + // Locks released in queueAndSend() + return queueAndSend() + } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + e.LockUser() + e.sndBufMu.Lock() + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - } - // Let the protocol goroutine do the work. - e.sndWaker.Assert() + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } - return int64(len(v)), nil, nil + // Locks released in queueAndSend() + return queueAndSend() } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. @@ -1339,6 +1389,9 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + switch opt { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. @@ -1346,9 +1399,6 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.mu.Lock() - defer e.mu.Unlock() - // We only allow this to be set when we're in the initial state. if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState @@ -1379,7 +1429,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { mask := uint32(notifyReceiveWindowChanged) - e.mu.RLock() + e.LockUser() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1409,8 +1459,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } + e.rcvListMu.Unlock() - e.mu.RUnlock() + e.UnlockUser() e.notifyProtocolGoroutine(mask) return nil @@ -1466,15 +1517,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.ReuseAddressOption: - e.mu.Lock() + e.LockUser() e.reuseAddr = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.ReusePortOption: - e.mu.Lock() + e.LockUser() e.reusePort = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.BindToDeviceOption: @@ -1482,9 +1533,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } - e.mu.Lock() + e.LockUser() e.bindToDevice = id - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.QuickAckOption: @@ -1500,16 +1551,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { return tcpip.ErrInvalidOptionValue } - e.mu.Lock() + e.LockUser() e.userMSS = uint16(userMSS) - e.mu.Unlock() + e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) return nil case tcpip.TTLOption: - e.mu.Lock() + e.LockUser() e.ttl = uint8(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.KeepaliveEnabledOption: @@ -1541,15 +1592,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.TCPUserTimeoutOption: - e.mu.Lock() + e.LockUser() e.userTimeout = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.BroadcastOption: - e.mu.Lock() + e.LockUser() e.broadcast = v != 0 - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.CongestionControlOption: @@ -1563,22 +1614,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { availCC := strings.Split(string(avail), " ") for _, cc := range availCC { if v == tcpip.CongestionControlOption(cc) { - // Acquire the work mutex as we may need to - // reinitialize the congestion control state. - e.mu.Lock() + e.LockUser() state := e.EndpointState() e.cc = v - e.mu.Unlock() switch state { case StateEstablished: - e.workMu.Lock() - e.mu.Lock() if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } - e.mu.Unlock() - e.workMu.Unlock() } + e.UnlockUser() return nil } } @@ -1588,23 +1633,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrNoSuchFile case tcpip.IPv4TOSOption: - e.mu.Lock() + e.LockUser() // TODO(gvisor.dev/issue/995): ECN is not currently supported, // ignore the bits for now. e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.IPv6TrafficClassOption: - e.mu.Lock() + e.LockUser() // TODO(gvisor.dev/issue/995): ECN is not currently supported, // ignore the bits for now. e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.TCPLingerTimeoutOption: - e.mu.Lock() + e.LockUser() if v < 0 { // Same as effectively disabling TCPLinger timeout. v = 0 @@ -1622,16 +1667,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { v = stkTCPLingerTimeout } e.tcpLingerTimeout = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil case tcpip.TCPDeferAcceptOption: - e.mu.Lock() + e.LockUser() if time.Duration(v) > MaxRTO { v = tcpip.TCPDeferAcceptOption(MaxRTO) } e.deferAccept = time.Duration(v) - e.mu.Unlock() + e.UnlockUser() return nil default: @@ -1641,8 +1686,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // readyReceiveSize returns the number of bytes ready to be received. func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // The endpoint cannot be in listen state. if e.EndpointState() == StateListen { @@ -1664,9 +1709,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrUnknownProtocolOption } - e.mu.Lock() + e.LockUser() v := e.v6only - e.mu.Unlock() + e.UnlockUser() return v, nil } @@ -1730,9 +1775,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.ReuseAddressOption: - e.mu.RLock() + e.LockUser() v := e.reuseAddr - e.mu.RUnlock() + e.UnlockUser() *o = 0 if v { @@ -1741,9 +1786,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.ReusePortOption: - e.mu.RLock() + e.LockUser() v := e.reusePort - e.mu.RUnlock() + e.UnlockUser() *o = 0 if v { @@ -1752,9 +1797,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.BindToDeviceOption: - e.mu.RLock() + e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.QuickAckOption: @@ -1765,16 +1810,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.TTLOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} - e.mu.RLock() + e.LockUser() snd := e.snd - e.mu.RUnlock() + e.UnlockUser() if snd != nil { snd.rtt.Lock() o.RTT = snd.rtt.srtt @@ -1813,9 +1858,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.TCPUserTimeoutOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPUserTimeoutOption(e.userTimeout) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.OutOfBandInlineOption: @@ -1824,9 +1869,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.BroadcastOption: - e.mu.Lock() + e.LockUser() v := e.broadcast - e.mu.Unlock() + e.UnlockUser() *o = 0 if v { @@ -1835,33 +1880,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil case *tcpip.CongestionControlOption: - e.mu.Lock() + e.LockUser() *o = e.cc - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.IPv4TOSOption: - e.mu.RLock() + e.LockUser() *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() + e.LockUser() *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() + e.UnlockUser() return nil case *tcpip.TCPLingerTimeoutOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) - e.mu.Unlock() + e.UnlockUser() return nil case *tcpip.TCPDeferAcceptOption: - e.mu.Lock() + e.LockUser() *o = tcpip.TCPDeferAcceptOption(e.deferAccept) - e.mu.Unlock() + e.UnlockUser() return nil default: @@ -1901,8 +1946,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // yet accepted by the app, they are restored without running the main goroutine // here. func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() connectingAddr := addr.Addr @@ -2071,9 +2116,13 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - e.mu.Lock() + e.LockUser() + defer e.UnlockUser() + return e.shutdownLocked(flags) +} + +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags - finQueued := false switch { case e.EndpointState().connected(): // Close for read. @@ -2087,24 +2136,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.mu.Unlock() - // Try to send an active reset immediately if the - // work mutex is available. - if e.workMu.TryLock() { - e.mu.Lock() - // We need to double check here to make - // sure worker has not transitioned the - // endpoint out of a connected state - // before trying to send a reset. - if e.EndpointState().connected() { - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.notifyProtocolGoroutine(notifyTickleWorker) - } - e.mu.Unlock() - e.workMu.Unlock() - } else { - e.notifyProtocolGoroutine(notifyReset) - } + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + // Wake up worker to terminate loop. + e.notifyProtocolGoroutine(notifyTickleWorker) return nil } } @@ -2116,42 +2150,32 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Already closed. e.sndBufMu.Unlock() if e.EndpointState() == StateTimeWait { - e.mu.Unlock() return tcpip.ErrNotConnected } - break + return nil } // Queue fin segment. s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() + e.handleClose() } + return nil case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } + return nil + default: - e.mu.Unlock() return tcpip.ErrNotConnected } - e.mu.Unlock() - if finQueued { - if e.workMu.TryLock() { - e.handleClose() - e.workMu.Unlock() - } else { - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() - } - } - return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept @@ -2166,8 +2190,8 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error { } func (e *endpoint) listen(backlog int) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -2229,7 +2253,6 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop() { - e.mu.Lock() e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2240,8 +2263,8 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() // Endpoint must be in listen state before it can accept connections. if e.EndpointState() != StateListen { @@ -2260,8 +2283,8 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { // Bind binds the endpoint to a specific local port and optionally address. func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.LockUser() + defer e.UnlockUser() return e.bindLocked(addr) } @@ -2339,8 +2362,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // GetLocalAddress returns the address to which the endpoint is bound. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() return tcpip.FullAddress{ Addr: e.ID.LocalAddress, @@ -2351,8 +2374,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress returns the address to which the endpoint is connected. func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() + e.LockUser() + defer e.UnlockUser() if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected @@ -2419,7 +2442,6 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { - e.mu.RLock() e.rcvListMu.Lock() if s != nil { s.incRef() @@ -2434,7 +2456,6 @@ func (e *endpoint) readyToRead(s *segment) { e.rcvClosed = true } e.rcvListMu.Unlock() - e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventIn) } @@ -2578,9 +2599,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SegTime = time.Now() // Copy EndpointID. - e.mu.Lock() s.ID = stack.TCPEndpointID(e.ID) - e.mu.Unlock() // Copy endpoint rcv state. e.rcvListMu.Lock() @@ -2710,10 +2729,10 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() + e.LockUser() // Make a copy of the endpoint info. ret := e.EndpointInfo - e.mu.RUnlock() + e.UnlockUser() return &ret } @@ -2728,9 +2747,9 @@ func (e *endpoint) Wait() { e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) defer e.waiterQueue.EventUnregister(&waitEntry) for { - e.mu.Lock() + e.LockUser() running := e.workerRunning - e.mu.Unlock() + e.UnlockUser() if !running { break } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4a46f0ec5..9175de441 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -162,8 +162,8 @@ func (e *endpoint) loadState(state EndpointState) { connectingLoading.Add(1) } // Directly update the state here rather than using e.setEndpointState - // as the endpoint is still being loaded and the stack reference to increment - // metrics is not yet initialized. + // as the endpoint is still being loaded and the stack reference is not + // yet initialized. atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } @@ -180,7 +180,6 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() state := e.origEndpointState switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 73098d904..b0f918bb4 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -95,7 +95,7 @@ const ( ) type protocol struct { - mu sync.Mutex + mu sync.RWMutex sackEnabled bool delayEnabled bool sendBufferSize SendBufferSizeOption @@ -273,57 +273,57 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: - p.mu.Lock() + p.mu.RLock() *v = SACKEnabled(p.sackEnabled) - p.mu.Unlock() + p.mu.RUnlock() return nil case *DelayEnabled: - p.mu.Lock() + p.mu.RLock() *v = DelayEnabled(p.delayEnabled) - p.mu.Unlock() + p.mu.RUnlock() return nil case *SendBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.sendBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *ReceiveBufferSizeOption: - p.mu.Lock() + p.mu.RLock() *v = p.recvBufferSize - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.CongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.CongestionControlOption(p.congestionControl) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.AvailableCongestionControlOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.ModerateReceiveBufferOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.TCPLingerTimeoutOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) - p.mu.Unlock() + p.mu.RUnlock() return nil case *tcpip.TCPTimeWaitTimeoutOption: - p.mu.Lock() + p.mu.RLock() *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) - p.mu.Unlock() + p.mu.RUnlock() return nil default: diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index d80aff1b6..caf8977b3 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -168,7 +168,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. - r.ep.mu.Lock() switch r.ep.EndpointState() { case StateEstablished: r.ep.setEndpointState(StateCloseWait) @@ -183,7 +182,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateFinWait2: r.ep.setEndpointState(StateTimeWait) } - r.ep.mu.Unlock() // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the @@ -208,7 +206,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { - r.ep.mu.Lock() switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -222,7 +219,6 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateLastAck: r.ep.transitionToStateCloseLocked() } - r.ep.mu.Unlock() } return true @@ -336,10 +332,8 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { - r.ep.mu.RLock() state := r.ep.EndpointState() closed := r.ep.closed - r.ep.mu.RUnlock() if state != StateEstablished { drop, err := r.handleRcvdSegmentClosing(s, state, closed) diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index bd20a7ee9..48a257137 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -28,10 +28,16 @@ type segmentQueue struct { used int } +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *segmentQueue) emptyLocked() bool { + return q.used == 0 +} + // empty determines if the queue is empty. func (q *segmentQueue) empty() bool { q.mu.Lock() - r := q.used == 0 + r := q.emptyLocked() q.mu.Unlock() return r diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 657c3146e..17fed4ec5 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -455,9 +455,7 @@ func (s *sender) retransmitTimerExpired() bool { // Give up if we've waited more than a minute since the last resend or // if a user time out is set and we have exceeded the user specified // timeout since the first retransmission. - s.ep.mu.RLock() uto := s.ep.userTimeout - s.ep.mu.RUnlock() if s.firstRetransmittedSegXmitTime.IsZero() { // We store the original xmitTime of the segment that we are @@ -713,7 +711,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.setEndpointState(StateFinWait1) } - } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 5b2b16afa..39d36d2ba 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2236,9 +2236,17 @@ func TestSegmentMerging(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - // Prevent the endpoint from processing packets. - test.stop(c.EP) + // Send 10 1 byte segments to fill up InitialWindow but don't + // ACK. That should prevent anymore packets from going out. + for i := 0; i < 10; i++ { + view := buffer.NewViewFromBytes([]byte{0}) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %v", i+1, err) + } + } + // Now send the segments that should get merged as the congestion + // window is full and we won't be able to send any more packets. var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) @@ -2248,8 +2256,29 @@ func TestSegmentMerging(t *testing.T) { } } - // Let the endpoint process the segments that we just sent. - test.resume(c.EP) + // Check that we get 10 packets of 1 byte each. + for i := 0; i < 10; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize+1), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+uint32(i)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. + RcvWnd: 30000, + }) // Check that data is received. b := c.GetPacket() @@ -2257,7 +2286,7 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+11), checker.AckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), @@ -2273,7 +2302,7 @@ func TestSegmentMerging(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))), + AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), RcvWnd: 30000, }) }) -- cgit v1.2.3 From 49aef9cee70d111f6c3e1a6b04430bbe414a6c1e Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Fri, 20 Mar 2020 15:24:00 -0700 Subject: Remove unused variable `sndNxtList`. PiperOrigin-RevId: 302110328 --- pkg/tcpip/transport/tcp/connect.go | 1 - pkg/tcpip/transport/tcp/snd.go | 5 ----- 2 files changed, 6 deletions(-) (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index edb37a549..53193afc6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -878,7 +878,6 @@ func (e *endpoint) handleWrite() *tcpip.Error { first := e.sndQueue.Front() if first != nil { e.snd.writeList.PushBackList(&e.sndQueue) - e.snd.sndNxtList.UpdateForward(e.sndBufInQueue) e.sndBufInQueue = 0 } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 17fed4ec5..6b7bac37d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -126,10 +126,6 @@ type sender struct { // sndNxt is the sequence number of the next segment to be sent. sndNxt seqnum.Value - // sndNxtList is the sequence number of the next segment to be added to - // the send list. - sndNxtList seqnum.Value - // rttMeasureSeqNum is the sequence number being used for the latest RTT // measurement. rttMeasureSeqNum seqnum.Value @@ -229,7 +225,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint sndWnd: sndWnd, sndUna: iss + 1, sndNxt: iss + 1, - sndNxtList: iss + 1, rto: 1 * time.Second, rttMeasureSeqNum: iss + 1, lastSendTime: time.Now(), -- cgit v1.2.3 From 28212b3f179dc23bb966f72b11f635017cdf8664 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 14 Apr 2020 19:32:32 -0700 Subject: Reduce flakiness in tcp_test. Tests now use a MinRTO of 3s instead of default 200ms. This reduced flakiness in a lot of the congestion control/recovery tests which were flaky due to retransmit timer firing too early in case the test executors were overloaded. This change also bumps some of the timeouts in tests which were too sensitive to timer variations and reduces the number of slow start iterations which can make the tests run for too long and also trigger retansmit timeouts etc if the executor is overloaded. PiperOrigin-RevId: 306562645 --- pkg/tcpip/checker/checker.go | 19 +++++ pkg/tcpip/link/channel/channel.go | 34 ++------- pkg/tcpip/tcpip.go | 4 ++ pkg/tcpip/transport/tcp/BUILD | 5 +- pkg/tcpip/transport/tcp/protocol.go | 17 +++++ pkg/tcpip/transport/tcp/snd.go | 15 +++- pkg/tcpip/transport/tcp/tcp_noracedetector_test.go | 83 ++++++++++++++-------- pkg/tcpip/transport/tcp/tcp_sack_test.go | 2 +- pkg/tcpip/transport/tcp/tcp_test.go | 50 ++++++------- pkg/tcpip/transport/tcp/testing/context/context.go | 17 ++++- 10 files changed, 159 insertions(+), 87 deletions(-) (limited to 'pkg/tcpip/transport/tcp/snd.go') diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 307f1b666..c1745ba6a 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -107,6 +107,8 @@ func DstAddr(addr tcpip.Address) NetworkChecker { // TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). func TTL(ttl uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + var v uint8 switch ip := h[0].(type) { case header.IPv4: @@ -310,6 +312,8 @@ func SrcPort(port uint16) TransportChecker { // DstPort creates a checker that checks the destination port. func DstPort(port uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if p := h.DestinationPort(); p != port { t.Errorf("Bad destination port, got %v, want %v", p, port) } @@ -336,6 +340,7 @@ func SeqNum(seq uint32) TransportChecker { func AckNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -350,6 +355,8 @@ func AckNum(seq uint32) TransportChecker { // Window creates a checker that checks the tcp window. func Window(window uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -381,6 +388,8 @@ func TCPFlags(flags uint8) TransportChecker { // given mask, match the supplied flags. func TCPFlagsMatch(flags, mask uint8) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -398,6 +407,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { // If wndscale is negative, the window scale option must not be present. func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -494,6 +505,8 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { // skipped. func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return @@ -612,6 +625,8 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { // Payload creates a checker that checks the payload. func Payload(want []byte) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if got := h.Payload(); !reflect.DeepEqual(got, want) { t.Errorf("Wrong payload, got %v, want %v", got, want) } @@ -644,6 +659,7 @@ func ICMPv4(checkers ...TransportChecker) NetworkChecker { func ICMPv4Type(want header.ICMPv4Type) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv4, ok := h.(header.ICMPv4) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) @@ -658,6 +674,7 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { func ICMPv4Code(want byte) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv4, ok := h.(header.ICMPv4) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) @@ -700,6 +717,7 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { func ICMPv6Type(want header.ICMPv6Type) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv6, ok := h.(header.ICMPv6) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) @@ -714,6 +732,7 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { func ICMPv6Code(want byte) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() + icmpv6, ok := h.(header.ICMPv6) if !ok { t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index b4a0ae53d..9bf67686d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -50,13 +50,11 @@ type NotificationHandle struct { } type queue struct { + // c is the outbound packet channel. + c chan PacketInfo // mu protects fields below. - mu sync.RWMutex - // c is the outbound packet channel. Sending to c should hold mu. - c chan PacketInfo - numWrite int - numRead int - notify []*NotificationHandle + mu sync.RWMutex + notify []*NotificationHandle } func (q *queue) Close() { @@ -64,11 +62,8 @@ func (q *queue) Close() { } func (q *queue) Read() (PacketInfo, bool) { - q.mu.Lock() - defer q.mu.Unlock() select { case p := <-q.c: - q.numRead++ return p, true default: return PacketInfo{}, false @@ -76,15 +71,8 @@ func (q *queue) Read() (PacketInfo, bool) { } func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { - // We have to receive from channel without holding the lock, since it can - // block indefinitely. This will cause a window that numWrite - numRead - // produces a larger number, but won't go to negative. numWrite >= numRead - // still holds. select { case pkt := <-q.c: - q.mu.Lock() - defer q.mu.Unlock() - q.numRead++ return pkt, true case <-ctx.Done(): return PacketInfo{}, false @@ -93,16 +81,12 @@ func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { func (q *queue) Write(p PacketInfo) bool { wrote := false - - // It's important to make sure nobody can see numWrite until we increment it, - // so numWrite >= numRead holds. - q.mu.Lock() select { case q.c <- p: wrote = true - q.numWrite++ default: } + q.mu.Lock() notify := q.notify q.mu.Unlock() @@ -116,13 +100,7 @@ func (q *queue) Write(p PacketInfo) bool { } func (q *queue) Num() int { - q.mu.RLock() - defer q.mu.RUnlock() - n := q.numWrite - q.numRead - if n < 0 { - panic("numWrite < numRead") - } - return n + return len(q.c) } func (q *queue) AddNotify(notify Notification) *NotificationHandle { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index aec7126ff..109121dbc 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -681,6 +681,10 @@ type TCPTimeWaitTimeoutOption time.Duration // for a handshake till the specified timeout until a segment with data arrives. type TCPDeferAcceptOption time.Duration +// TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding +// default MinRTO used by the Stack. +type TCPMinRTOOption time.Duration + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7f94f9646..edb7718a6 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -87,7 +87,9 @@ go_test( "tcp_timestamp_test.go", ], # FIXME(b/68809571) - tags = ["flaky"], + tags = [ + "flaky", + ], deps = [ ":tcp", "//pkg/sync", @@ -104,5 +106,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", "//pkg/waiter", + "//runsc/testutil", ], ) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index dce9a1652..91f25c132 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -105,6 +105,7 @@ type protocol struct { moderateReceiveBuffer bool tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration + minRTO time.Duration dispatcher *dispatcher } @@ -272,6 +273,15 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPMinRTOOption: + if v < 0 { + v = tcpip.TCPMinRTOOption(MinRTO) + } + p.mu.Lock() + p.minRTO = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -334,6 +344,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *tcpip.TCPMinRTOOption: + p.mu.RLock() + *v = tcpip.TCPMinRTOOption(p.minRTO) + p.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -359,5 +375,6 @@ func NewProtocol() stack.TransportProtocol { tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), + minRTO: MinRTO, } } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 6b7bac37d..d8cfe3115 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "math" "sync/atomic" "time" @@ -149,6 +150,9 @@ type sender struct { rtt rtt rto time.Duration + // minRTO is the minimum permitted value for sender.rto. + minRTO time.Duration + // maxPayloadSize is the maximum size of the payload of a given segment. // It is initialized on demand. maxPayloadSize int @@ -260,6 +264,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // etc. s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + // Get Stack wide minRTO. + var v tcpip.TCPMinRTOOption + if err := ep.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { + panic(fmt.Sprintf("unable to get minRTO from stack: %s", err)) + } + s.minRTO = time.Duration(v) + return s } @@ -394,8 +405,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < MinRTO { - s.rto = MinRTO + if s.rto < s.minRTO { + s.rto = s.minRTO } } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 782d7b42c..359a75e73 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/runsc/testutil" ) func TestFastRecovery(t *testing.T) { @@ -40,7 +41,7 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -86,16 +87,23 @@ func TestFastRecovery(t *testing.T) { // Receive the retransmitted packet. c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) - } + // Wait before checking metrics. + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) + } + return nil } - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Now send 7 mode duplicate acks. Each of these should cause a window @@ -117,12 +125,18 @@ func TestFastRecovery(t *testing.T) { // Receive the retransmit due to partial ack. c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + // Wait before checking metrics. + metricPollFn = func() error { + if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { + return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + } + return nil } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Receive the 10 extra packets that should have been released due to @@ -192,7 +206,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -234,7 +248,7 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -338,7 +352,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { @@ -447,7 +461,7 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const iterations = 7 + const iterations = 3 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) @@ -492,24 +506,33 @@ func TestRetransmit(t *testing.T) { rtxOffset := bytesRead - maxPayload*expected c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) - } + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) + } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) - } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) - } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + } + + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) + } - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + return nil } - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) + // Poll when checking metrics. + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) } // Acknowledge half of the pending data. diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index afea124ec..c439d5281 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -387,7 +387,7 @@ func TestSACKRecovery(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - const iterations = 7 + const iterations = 3 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) for i := range data { data[i] = byte(i) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 29301a45c..41caa9ed4 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -590,6 +590,10 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { ), ) + // Give the stack a few ms to transition the endpoint out of ESTABLISHED + // state. + time.Sleep(10 * time.Millisecond) + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } @@ -4472,8 +4476,8 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - const keepAliveInterval = 10 * time.Millisecond - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + const keepAliveInterval = 3 * time.Second + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) @@ -4567,7 +4571,7 @@ func TestKeepalive(t *testing.T) { // Sleep for a litte over the KeepAlive interval to make sure // the timer has time to fire after the last ACK and close the // close the socket. - time.Sleep(keepAliveInterval + 5*time.Millisecond) + time.Sleep(keepAliveInterval + keepAliveInterval/2) // The connection should be terminated after 5 unacked keepalives. // Send an ACK to trigger a RST from the stack as the endpoint should @@ -6615,14 +6619,17 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - const keepAliveInterval = 10 * time.Millisecond - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + const keepAliveInterval = 3 * time.Second + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) - // Set userTimeout to be the duration for 3 keepalive probes. - userTimeout := 30 * time.Millisecond + // Set userTimeout to be the duration to be 1 keepalive + // probes. Which means that after the first probe is sent + // the second one should cause the connection to be + // closed due to userTimeout being hit. + userTimeout := 1 * keepAliveInterval c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) // Check that the connection is still alive. @@ -6630,28 +6637,23 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } - // Now receive 2 keepalives, but don't ACK them. The connection should - // be reset when the 3rd one should be sent due to userTimeout being - // 30ms and each keepalive probe should be sent 10ms apart as set above after - // the connection has been idle for 10ms. - for i := 0; i < 2; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } + // Now receive 1 keepalives, but don't ACK it. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) // Sleep for a litte over the KeepAlive interval to make sure // the timer has time to fire after the last ACK and close the // close the socket. - time.Sleep(keepAliveInterval + 5*time.Millisecond) + time.Sleep(keepAliveInterval + keepAliveInterval/2) - // The connection should be terminated after 30ms. + // The connection should be closed with a timeout. // Send an ACK to trigger a RST from the stack as the endpoint should // be dead. c.SendPacket(nil, &context.Headers{ diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 431ab4e6b..7b1d72cf4 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -152,6 +152,13 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("SetTransportProtocolOption failed: %v", err) } + // Increase minimum RTO in tests to avoid test flakes due to early + // retransmit in case the test executors are overloaded and cause timers + // to fire earlier than expected. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { + t.Fatalf("failed to set stack-wide minRTO: %s", err) + } + // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. ep := channel.New(1000, mtu, "") @@ -236,7 +243,7 @@ func (c *Context) CheckNoPacket(errMsg string) { func (c *Context) GetPacket() []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { @@ -417,6 +424,8 @@ func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlock // verifies that the packet packet payload of packet matches the slice // of data indicated by offset & size. func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { + c.t.Helper() + c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) } @@ -425,6 +434,8 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { // data indicated by offset & size and skips optlen bytes in addition to the IP // TCP headers when comparing the data. func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { + c.t.Helper() + b := c.GetPacket() checker.IPv4(c.t, b, checker.PayloadLen(size+header.TCPMinimumSize+optlen), @@ -447,6 +458,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op // data indicated by offset & size. It returns true if a packet was received and // processed. func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { + c.t.Helper() + b := c.GetPacketNonBlocking() if b == nil { return false @@ -570,6 +583,8 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf // // PreCondition: c.EP must already be created. func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { + c.t.Helper() + // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) -- cgit v1.2.3