From fff04769518b279a364c928307a71055eaa6166d Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Mon, 13 Jan 2020 13:08:36 -0800 Subject: benchmarks/tcp: set a number of channels to GOMAXPROCS Updates #231 PiperOrigin-RevId: 289502669 --- benchmarks/tcp/tcp_benchmark.sh | 21 ++++++++++++++- benchmarks/tcp/tcp_proxy.go | 58 ++++++++++++++++++++++------------------- 2 files changed, 51 insertions(+), 28 deletions(-) (limited to 'benchmarks') diff --git a/benchmarks/tcp/tcp_benchmark.sh b/benchmarks/tcp/tcp_benchmark.sh index 69344c9c3..e65801a7b 100755 --- a/benchmarks/tcp/tcp_benchmark.sh +++ b/benchmarks/tcp/tcp_benchmark.sh @@ -41,6 +41,8 @@ duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses. duration=30 # 30s is enough time to consistent results (experimentally). helper_dir=$(dirname $0) netstack_opts= +disable_linux_gso= +num_client_threads=1 # Check for netem support. lsmod_output=$(lsmod | grep sch_netem) @@ -125,6 +127,13 @@ while [ $# -gt 0 ]; do shift netstack_opts="${netstack_opts} -memprofile=$1" ;; + --disable-linux-gso) + disable_linux_gso=1 + ;; + --num-client-threads) + shift + num_client_threads=$1 + ;; --helpers) shift [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1 @@ -147,6 +156,8 @@ while [ $# -gt 0 ]; do echo " --loss set the loss probability (%)" echo " --duplicate set the duplicate probability (%)" echo " --helpers set the helper directory" + echo " --num-client-threads number of parallel client threads to run" + echo " --disable-linux-gso disable segmentation offload in the Linux network stack" echo "" echo "The output will of the script will be:" echo " " @@ -301,6 +312,14 @@ fi # Add client and server addresses, and bring everything up. ${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0 ${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0 +if [ "${disable_linux_gso}" == "1" ]; then + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off +fi ${nsjoin_binary} /tmp/client.netns ip link set client.0 up ${nsjoin_binary} /tmp/client.netns ip link set lo up ${nsjoin_binary} /tmp/server.netns ip link set server.0 up @@ -338,7 +357,7 @@ trap cleanup EXIT # Run the benchmark, recording the results file. while ${nsjoin_binary} /tmp/client.netns iperf \\ - -p ${proxy_port} -c ${client_addr} -t ${duration} -f m 2>&1 \\ + -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\ | tee \$results_file \\ | grep "connect failed" >/dev/null; do sleep 0.1 # Wait for all services. diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index 361a56755..be0d7bdd6 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -94,11 +94,11 @@ type netstackImpl struct { mode string } -func setupNetwork(ifaceName string) (fd int, err error) { +func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { // Get all interfaces in the namespace. ifaces, err := net.Interfaces() if err != nil { - return -1, fmt.Errorf("querying interfaces: %v", err) + return nil, fmt.Errorf("querying interfaces: %v", err) } for _, iface := range ifaces { @@ -107,39 +107,43 @@ func setupNetwork(ifaceName string) (fd int, err error) { } // Create the socket. const protocol = 0x0300 // htons(ETH_P_ALL) - fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) - if err != nil { - return -1, fmt.Errorf("unable to create raw socket: %v", err) - } + fds := make([]int, numChannels) + for i := range fds { + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + if err != nil { + return nil, fmt.Errorf("unable to create raw socket: %v", err) + } - // Bind to the appropriate device. - ll := syscall.SockaddrLinklayer{ - Protocol: protocol, - Ifindex: iface.Index, - Pkttype: syscall.PACKET_HOST, - } - if err := syscall.Bind(fd, &ll); err != nil { - return -1, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) - } + // Bind to the appropriate device. + ll := syscall.SockaddrLinklayer{ + Protocol: protocol, + Ifindex: iface.Index, + Pkttype: syscall.PACKET_HOST, + } + if err := syscall.Bind(fd, &ll); err != nil { + return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) + } - // RAW Sockets by default have a very small SO_RCVBUF of 256KB, - // up it to at least 1MB to reduce packet drops. - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, rcvBufSize); err != nil { - return -1, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) - } + // RAW Sockets by default have a very small SO_RCVBUF of 256KB, + // up it to at least 1MB to reduce packet drops. + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, rcvBufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + } - if !*swgso && *gso != 0 { - if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { - return -1, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) + if !*swgso && *gso != 0 { + if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) + } } + fds[i] = fd } - return fd, nil + return fds, nil } - return -1, fmt.Errorf("failed to find interface: %v", ifaceName) + return nil, fmt.Errorf("failed to find interface: %v", ifaceName) } func newNetstackImpl(mode string) (impl, error) { - fd, err := setupNetwork(*iface) + fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1)) if err != nil { return nil, err } @@ -177,7 +181,7 @@ func newNetstackImpl(mode string) (impl, error) { mac[0] &^= 0x1 // Clear multicast bit. mac[0] |= 0x2 // Set local assignment bit (IEEE802). ep, err := fdbased.New(&fdbased.Options{ - FDs: []int{fd}, + FDs: fds, MTU: uint32(*mtu), EthernetHeader: true, Address: tcpip.LinkAddress(mac), -- cgit v1.2.3 From a611fdaee3c14abe2222140ae0a8a742ebfd31ab Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 14 Jan 2020 14:14:17 -0800 Subject: Changes TCP packet dispatch to use a pool of goroutines. All inbound segments for connections in ESTABLISHED state are delivered to the endpoint's queue but for every segment delivered we also queue the endpoint for processing to a selected processor. This ensures that when there are a large number of connections in ESTABLISHED state the inbound packets are all handled by a small number of goroutines and significantly reduces the amount of work the goscheduler has to perform. We let connections in other states follow the current path where the endpoint's goroutine directly handles the segments. Updates #231 PiperOrigin-RevId: 289728325 --- benchmarks/tcp/tcp_proxy.go | 6 +- pkg/sleep/sleep_test.go | 31 +++ pkg/tcpip/stack/transport_demuxer.go | 54 ++++- pkg/tcpip/transport/tcp/BUILD | 15 +- pkg/tcpip/transport/tcp/accept.go | 9 +- pkg/tcpip/transport/tcp/connect.go | 310 ++++++++++++++++------------ pkg/tcpip/transport/tcp/dispatcher.go | 218 +++++++++++++++++++ pkg/tcpip/transport/tcp/endpoint.go | 303 ++++++++++++++++++--------- pkg/tcpip/transport/tcp/endpoint_state.go | 30 +-- pkg/tcpip/transport/tcp/protocol.go | 11 + pkg/tcpip/transport/tcp/rcv.go | 21 +- pkg/tcpip/transport/tcp/snd.go | 14 +- pkg/tcpip/transport/tcp/tcp_test.go | 11 +- test/syscalls/linux/socket_inet_loopback.cc | 2 +- test/syscalls/linux/tcp_socket.cc | 14 ++ 15 files changed, 769 insertions(+), 280 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/dispatcher.go (limited to 'benchmarks') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index be0d7bdd6..dc96add66 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -85,7 +85,7 @@ func (netImpl) printStats() { const ( nicID = 1 // Fixed. - rcvBufSize = 1 << 20 // 1MB. + rcvBufSize = 4 << 20 // 1MB. ) type netstackImpl struct { @@ -130,6 +130,10 @@ func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) } + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, rcvBufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + } + if !*swgso && *gso != 0 { if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index 130806c86..af47e2ba1 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -376,6 +376,37 @@ func TestRace(t *testing.T) { } } +// TestRaceInOrder tests that multiple wakers can continuously send wake requests to +// the sleeper and that the wakers are retrieved in the order asserted. +func TestRaceInOrder(t *testing.T) { + const wakers = 100 + const wakeRequests = 10000 + + w := make([]Waker, wakers) + s := Sleeper{} + + // Associate each waker and start goroutines that will assert them. + for i := range w { + s.AddWaker(&w[i], i) + } + go func() { + n := 0 + for n < wakeRequests { + wk := w[n%len(w)] + wk.Assert() + n++ + } + }() + + // Wait for all wake up notifications from all wakers. + for i := 0; i < wakeRequests; i++ { + v, _ := s.Fetch(true) + if got, want := v, i%wakers; got != want { + t.Fatalf("got %d want %d", got, want) + } + } +} + // BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up // from 4 wakers when at least one is already asserted. func BenchmarkSleeperMultiSelect(b *testing.B) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f384a91de..d686e6eb8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + transEP := selectEndpoint(id, mpep, epsByNic.seed) + if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { + queuedProtocol.QueuePacket(r, transEP, id, pkt) + epsByNic.mu.RUnlock() + return + } + + transEP.HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } @@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { epsByNic.mu.Lock() defer epsByNic.mu.Unlock() @@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort } // This is a new binding. - multiPortEp := &multiPortEndpoint{} + multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} multiPortEp.endpointsMap = make(map[TransportEndpoint]int) multiPortEp.reuse = reusePort epsByNic.endpoints[bindToDevice] = multiPortEp @@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T // newTransportDemuxer. type transportDemuxer struct { // protocol is immutable. - protocol map[protocolIDs]*transportEndpoints + protocol map[protocolIDs]*transportEndpoints + queuedProtocols map[protocolIDs]queuedTransportProtocol +} + +// queuedTransportProtocol if supported by a protocol implementation will cause +// the dispatcher to delivery packets to the QueuePacket method instead of +// calling HandlePacket directly on the endpoint. +type queuedTransportProtocol interface { + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { - d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + d := &transportDemuxer{ + protocol: make(map[protocolIDs]*transportEndpoints), + queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), + } // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + protoIDs := protocolIDs{netProto, proto} + d.protocol[protoIDs] = &transportEndpoints{ endpoints: make(map[TransportEndpointID]*endpointsByNic), } + qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) + if isQueued { + d.queuedProtocols[protoIDs] = qTransProto + } } } @@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + demux *transportDemuxer + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. @@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() + queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] for i, endpoint := range ep.endpointsArr { // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + break + } endpoint.HandlePacket(r, id, pkt) break } + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + continue + } endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. @@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if epsByNic, ok := eps.endpoints[id]; ok { // There was already a binding. - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // This is a new binding. @@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol } eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 353bd06f4..0e3ab05ad 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -16,6 +16,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tcp_endpoint_list", + out = "tcp_endpoint_list.go", + package = "tcp", + prefix = "endpoint", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*endpoint", + "Linker": "*endpoint", + }, +) + go_library( name = "tcp", srcs = [ @@ -23,6 +35,7 @@ go_library( "connect.go", "cubic.go", "cubic_state.go", + "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", @@ -38,6 +51,7 @@ go_library( "segment_state.go", "snd.go", "snd_state.go", + "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", ], @@ -45,7 +59,6 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ - "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1ea996936..1a2e3efa9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. if l.listenEP != nil { l.listenEP.mu.Lock() - if l.listenEP.state != StateListen { + if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() return nil, tcpip.ErrConnectionAborted } @@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() { // instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.state + state := e.EndpointState() e.pendingAccepted.Add(1) defer e.pendingAccepted.Done() acceptedChan := e.acceptedChan e.mu.Unlock() + if state == StateListen { acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) @@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // We do not use transitionToStateEstablishedLocked here as there is // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() - n.state = StateEstablished n.isConnectNotified = true + n.setEndpointState(StateEstablished) // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = StateClose + e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 613ec1775..f3896715b 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.ep.mu.Lock() - h.ep.state = StateSynRecv + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() } @@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // SYN-RCVD state. h.state = handshakeSynRcvd h.ep.mu.Lock() - h.ep.state = StateSynRecv ttl := h.ep.ttl + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), // We only send SACKPermitted if the other side indicated it // permits SACK. This is not explicitly defined in the RFC but @@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { WS: h.rcvWndScale, TS: h.ep.sendTSOk, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } @@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error { WS: h.rcvWndScale, TS: true, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } @@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } if e.sackPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) @@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error { } func (e *endpoint) handleClose() *tcpip.Error { + if !e.EndpointState().connected() { + return nil + } // Drain the send queue. e.handleWrite() @@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } - e.state = StateError + e.setEndpointState(StateError) e.HardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the @@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } // completeWorkerLocked is called by the worker goroutine when it's about to -// exit. It marks the worker as completed and performs cleanup work if requested -// by Close(). +// exit. func (e *endpoint) completeWorkerLocked() { + // Worker is terminating(either due to moving to + // CLOSED or ERROR state, ensure we release all + // registrations port reservations even if the socket + // itself is not yet closed by the application. e.workerRunning = false if e.workerCleanup { e.cleanupLocked() @@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { e.rcvAutoParams.prevCopied = int(h.rcvWnd) e.rcvListMu.Unlock() } - h.ep.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished + e.setEndpointState(StateEstablished) } // transitionToStateCloseLocked ensures that the endpoint is @@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // delivered to this endpoint from the demuxer when the endpoint // is transitioned to StateClose. func (e *endpoint) transitionToStateCloseLocked() { - if e.state == StateClose { + if e.EndpointState() == StateClose { return } + // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() - e.state = StateClose + e.setEndpointState(StateClose) e.stack.Stats().TCP.EstablishedClosed.Increment() } @@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { s.decRef() return } - ep.(*endpoint).enqueueSegment(s) + if ep.(*endpoint).enqueueSegment(s) { + ep.(*endpoint).newSegmentWaker.Assert() + } } func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { @@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - s.decRef() e.mu.Lock() - switch e.state { + switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set // to indicate EPIPE. @@ -981,103 +984,57 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: e.mu.Unlock() + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + + // Notify protocol goroutine. This is required when + // handleSegment is invoked from the processor goroutine + // rather than the worker goroutine. + e.notifyProtocolGoroutine(notifyResetByPeer) return false, tcpip.ErrConnectionReset } } return true, nil } -// handleSegments pulls segments from the queue and processes them. It returns -// no error if the protocol loop should continue, an error otherwise. -func (e *endpoint) handleSegments() *tcpip.Error { +// handleSegments processes all inbound segments. +func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + if e.EndpointState() == StateClose || e.EndpointState() == StateError { + return nil + } s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false break } - // Invoke the tcp probe if installed. - if e.probe != nil { - e.probe(e.completeState()) + cont, err := e.handleSegment(s) + if err != nil { + s.decRef() + e.mu.Lock() + e.setEndpointState(StateError) + e.HardError = err + e.mu.Unlock() + return err } - - if s.flagIsSet(header.TCPFlagRst) { - if ok, err := e.handleReset(s); !ok { - return err - } - } else if s.flagIsSet(header.TCPFlagSyn) { - // See: https://tools.ietf.org/html/rfc5961#section-4.1 - // 1) If the SYN bit is set, irrespective of the sequence number, TCP - // MUST send an ACK (also referred to as challenge ACK) to the remote - // peer: - // - // - // - // After sending the acknowledgment, TCP MUST drop the unacceptable - // segment and stop processing further. - // - // By sending an ACK, the remote peer is challenged to confirm the loss - // of the previous connection and the request to start a new connection. - // A legitimate peer, after restart, would not have a TCB in the - // synchronized state. Thus, when the ACK arrives, the peer should send - // a RST segment back with the sequence number derived from the ACK - // field that caused the RST. - - // This RST will confirm that the remote peer has indeed closed the - // previous connection. Upon receipt of a valid RST, the local TCP - // endpoint MUST terminate its connection. The local TCP endpoint - // should then rely on SYN retransmission from the remote end to - // re-establish the connection. - - e.snd.sendAck() - } else if s.flagIsSet(header.TCPFlagAck) { - // Patch the window size in the segment according to the - // send window scale. - s.window <<= e.snd.sndWndScale - - // RFC 793, page 41 states that "once in the ESTABLISHED - // state all segments must carry current acknowledgment - // information." - drop, err := e.rcv.handleRcvdSegment(s) - if err != nil { - s.decRef() - return err - } - if drop { - s.decRef() - continue - } - - // Now check if the received segment has caused us to transition - // to a CLOSED state, if yes then terminate processing and do - // not invoke the sender. - e.mu.RLock() - state := e.state - e.mu.RUnlock() - if state == StateClose { - // When we get into StateClose while processing from the queue, - // return immediately and let the protocolMainloop handle it. - // - // We can reach StateClose only while processing a previous segment - // or a notification from the protocolMainLoop (caller goroutine). - // This means that with this return, the segment dequeue below can - // never occur on a closed endpoint. - s.decRef() - return nil - } - e.snd.handleRcvdSegment(s) + if !cont { + s.decRef() + return nil } - s.decRef() } - // If the queue is not empty, make sure we'll wake up in the next - // iteration. - if checkRequeue && !e.segmentQueue.empty() { + // When fastPath is true we don't want to wake up the worker + // goroutine. If the endpoint has more segments to process the + // dispatcher will call handleSegments again anyway. + if !fastPath && checkRequeue && !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } @@ -1086,11 +1043,88 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } - e.resetKeepaliveTimer(true) + e.resetKeepaliveTimer(true /* receivedData */) return nil } +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(header.TCPFlagRst) { + if ok, err := e.handleReset(s); !ok { + return false, err + } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + return false, err + } + if drop { + return true, nil + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return false, nil + } + + e.snd.handleRcvdSegment(s) + } + + return true, nil +} + // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. @@ -1160,7 +1194,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1182,6 +1216,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() + e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1193,7 +1228,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) e.mu.Lock() - h.ep.state = StateSynSent + h.ep.setEndpointState(StateSynSent) e.mu.Unlock() if err := h.execute(); err != nil { @@ -1202,12 +1237,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.state = StateError + e.setEndpointState(StateError) e.HardError = err // Lock released below. epilogue() - return err } } @@ -1215,7 +1249,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - // Tell waiters that the endpoint is connected and writable. e.mu.Lock() drained := e.drainDone != nil e.mu.Unlock() @@ -1224,8 +1257,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { <-e.undrain } - e.waiterQueue.Notify(waiter.EventOut) - // Set up the functions that will be called when the main protocol loop // wakes up. funcs := []struct { @@ -1240,18 +1271,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.sndCloseWaker, f: e.handleClose, }, - { - w: &e.newSegmentWaker, - f: e.handleSegments, - }, { w: &closeWaker, f: func() *tcpip.Error { // This means the socket is being closed due - // to the TCP_FIN_WAIT2 timeout was hit. Just + // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() e.transitionToStateCloseLocked() + e.workerCleanup = true e.mu.Unlock() return nil }, @@ -1266,6 +1294,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil }, }, + { + w: &e.newSegmentWaker, + f: func() *tcpip.Error { + return e.handleSegments(false /* fastPath */) + }, + }, { w: &e.keepalive.waker, f: e.keepaliveTimerExpired, @@ -1293,14 +1327,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } if n¬ifyReset != 0 { - e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.mu.Unlock() + return tcpip.ErrConnectionAborted + } + + if n¬ifyResetByPeer != 0 { + return tcpip.ErrConnectionReset } if n¬ifyClose != 0 && closeTimer == nil { e.mu.Lock() - if e.state == StateFinWait2 && e.closed { + if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { @@ -1320,11 +1356,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(); err != nil { + if err := e.handleSegments(false /* fastPath */); err != nil { return err } } - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1349,14 +1385,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.AddWaker(funcs[i].w, i) } + // Notify the caller that the waker initialization is complete and the + // endpoint is ready. + if wakerInitDone != nil { + close(wakerInitDone) + } + + // Tell waiters that the endpoint is connected and writable. + e.waiterQueue.Notify(waiter.EventOut) + // The following assertions and notifications are needed for restored // endpoints. Fresh newly created endpoints have empty states and should // not invoke any. - e.segmentQueue.mu.Lock() - if !e.segmentQueue.list.Empty() { + if !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } - e.segmentQueue.mu.Unlock() e.rcvListMu.Lock() if !e.rcvList.Empty() { @@ -1372,27 +1415,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() + // We need to double check here because the notification + // maybe stale by the time we got around to processing it. + // NOTE: since we now hold the workMu the processors cannot + // change the state of the endpoint so it' safe to proceed + // after this check. + if e.EndpointState() == StateTimeWait || e.EndpointState() == StateClose || e.EndpointState() == StateError { + e.mu.Lock() + break + } if err := funcs[v].f(); err != nil { e.mu.Lock() - // Ensure we release all endpoint registration and route - // references as the connection is now in an error - // state. e.workerCleanup = true e.resetConnectionLocked(err) // Lock released below. epilogue() - return nil } e.mu.Lock() } - state := e.state + state := e.EndpointState() e.mu.Unlock() var reuseTW func() if state == StateTimeWait { @@ -1405,13 +1453,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.mu.Lock() + e.workerCleanup = true + e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() - if e.state != StateError { - e.stack.Stats().TCP.CurrentEstablished.Decrement() + if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1468,7 +1518,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { - tcpEP.enqueueSegment(s) + if !tcpEP.enqueueSegment(s) { + s.decRef() + return + } + tcpEP.newSegmentWaker.Assert() } // We explicitly do not decRef // the segment as it's still diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go new file mode 100644 index 000000000..a72f0c379 --- /dev/null +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -0,0 +1,218 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// epQueue is a queue of endpoints. +type epQueue struct { + mu sync.Mutex + list endpointList +} + +// enqueue adds e to the queue if the endpoint is not already on the queue. +func (q *epQueue) enqueue(e *endpoint) { + q.mu.Lock() + if e.pendingProcessing { + q.mu.Unlock() + return + } + q.list.PushBack(e) + e.pendingProcessing = true + q.mu.Unlock() +} + +// dequeue removes and returns the first element from the queue if available, +// returns nil otherwise. +func (q *epQueue) dequeue() *endpoint { + q.mu.Lock() + if e := q.list.Front(); e != nil { + q.list.Remove(e) + e.pendingProcessing = false + q.mu.Unlock() + return e + } + q.mu.Unlock() + return nil +} + +// empty returns true if the queue is empty, false otherwise. +func (q *epQueue) empty() bool { + q.mu.Lock() + v := q.list.Empty() + q.mu.Unlock() + return v +} + +// processor is responsible for processing packets queued to a tcp endpoint. +type processor struct { + epQ epQueue + newEndpointWaker sleep.Waker + id int +} + +func newProcessor(id int) *processor { + p := &processor{ + id: id, + } + go p.handleSegments() + return p +} + +func (p *processor) queueEndpoint(ep *endpoint) { + // Queue an endpoint for processing by the processor goroutine. + p.epQ.enqueue(ep) + p.newEndpointWaker.Assert() +} + +func (p *processor) handleSegments() { + const newEndpointWaker = 1 + s := sleep.Sleeper{} + s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + defer s.Done() + for { + s.Fetch(true) + for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + if ep.segmentQueue.empty() { + continue + } + + // If socket has transitioned out of connected state + // then just let the worker handle the packet. + // + // NOTE: We read this outside of e.mu lock which means + // that by the time we get to handleSegments the + // endpoint may not be in ESTABLISHED. But this should + // be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a + // CLOSED/ERROR state then handleSegments is a noop. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + continue + } + + if !ep.workMu.TryLock() { + ep.newSegmentWaker.Assert() + continue + } + // If the endpoint is in a connected state then we do + // direct delivery to ensure low latency and avoid + // scheduler interactions. + if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { + ep.notifyProtocolGoroutine(notifyTickleWorker) + ep.workMu.Unlock() + continue + } + + if !ep.segmentQueue.empty() { + p.epQ.enqueue(ep) + } + + ep.workMu.Unlock() + } + } +} + +// dispatcher manages a pool of TCP endpoint processors which are responsible +// for the processing of inbound segments. This fixed pool of processor +// goroutines do full tcp processing. The processor is selected based on the +// hash of the endpoint id to ensure that delivery for the same endpoint happens +// in-order. +type dispatcher struct { + processors []*processor + seed uint32 +} + +func newDispatcher(nProcessors int) *dispatcher { + processors := []*processor{} + for i := 0; i < nProcessors; i++ { + processors = append(processors, newProcessor(i)) + } + return &dispatcher{ + processors: processors, + seed: generateRandUint32(), + } +} + +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + ep := stackEP.(*endpoint) + s := newSegment(r, id, pkt) + if !s.parse() { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + s.decRef() + return + } + + if !s.csumValid { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.ChecksumErrors.Increment() + ep.stats.ReceiveErrors.ChecksumErrors.Increment() + s.decRef() + return + } + + ep.stack.Stats().TCP.ValidSegmentsReceived.Increment() + ep.stats.SegmentsReceived.Increment() + if (s.flags & header.TCPFlagRst) != 0 { + ep.stack.Stats().TCP.ResetsReceived.Increment() + } + + if !ep.enqueueSegment(s) { + s.decRef() + return + } + + // For sockets not in established state let the worker goroutine + // handle the packets. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + return + } + + d.selectProcessor(id).queueEndpoint(ep) +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { + payload := []byte{ + byte(id.LocalPort), + byte(id.LocalPort >> 8), + byte(id.RemotePort), + byte(id.RemotePort >> 8)} + + h := jenkins.Sum32(d.seed) + h.Write(payload) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + + return d.processors[h.Sum32()%uint32(len(d.processors))] +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cc8b533c8..1799c6e10 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -120,6 +120,7 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyResetByPeer notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -127,6 +128,7 @@ const ( // ensures the loop terminates if the final state of the endpoint is // say TIME_WAIT. notifyTickleWorker + notifyError ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {} type endpoint struct { EndpointInfo + // endpointEntry is used to queue endpoints for processing to the + // a given tcp processor goroutine. + // + // Precondition: epQueue.mu must be held to read/write this field.. + endpointEntry `state:"nosave"` + + // pendingProcessing is true if this endpoint is queued for processing + // to a TCP processor. + // + // Precondition: epQueue.mu must be held to read/write this field.. + pendingProcessing bool `state:"nosave"` + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -324,6 +338,7 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` + // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -359,7 +374,7 @@ type endpoint struct { workerRunning bool // workerCleanup specifies if the worker goroutine must perform cleanup - // before exitting. This can only be set to true when workerRunning is + // before exiting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. workerCleanup bool @@ -371,6 +386,8 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. + // + // recentTS must be read/written atomically. recentTS uint32 // tsOffset is a randomized offset added to the value of the @@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() { e.workMu.Unlock() } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + switch state { + case StateEstablished: + e.stack.Stats().TCP.CurrentEstablished.Increment() + case StateError: + fallthrough + case StateClose: + if oldstate == StateCloseWait || oldstate == StateEstablished { + e.stack.Stats().TCP.EstablishedResets.Increment() + } + fallthrough + default: + if oldstate == StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } + } + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp atomically sets the recentTS field to the +// provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + atomic.StoreUint32(&e.recentTS, recentTS) +} + +// recentTimestamp atomically reads and returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return atomic.LoadUint32(&e.recentTS) +} + // keepalive is a synchronization wrapper used to appease stateify. See the // comment in endpoint, where it is used. // @@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.mu.RLock() defer e.mu.RUnlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } } - if e.state.connected() { + if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -733,14 +791,20 @@ func (e *endpoint) Close() { // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdown() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdown() { e.mu.Lock() // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == StateListen && e.isPortReserved { + if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false @@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { defer close(done) for n := range e.acceptedChan { n.notifyProtocolGoroutine(notifyReset) + // close all connections that have completed but + // not accepted by the application. n.Close() } }() @@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { e.closePendingAcceptableConnectionsLocked() } + e.workerCleanup = false if e.isRegistered { @@ -920,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError e.mu.RUnlock() @@ -944,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -980,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. - if !e.state.connected() { - switch e.state { + if !e.EndpointState().connected() { + switch e.EndpointState() { case StateError: return 0, e.HardError default: @@ -1039,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, perr } - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + if opts.Atomic { + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { + if e.workMu.TryLock() { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] } - } - - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - if e.workMu.TryLock() { // Do the work inline. e.handleWrite() e.workMu.Unlock() } else { + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + + } // Let the protocol goroutine do the work. e.sndWaker.Assert() } @@ -1091,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; !s.connected() && s != StateClose { + if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { return 0, tcpip.ControlMessages{}, e.HardError } @@ -1103,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } @@ -1187,7 +1299,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1402,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Acquire the work mutex as we may need to // reinitialize the congestion control state. e.mu.Lock() - state := e.state + state := e.EndpointState() e.cc = v e.mu.Unlock() switch state { case StateEstablished: e.workMu.Lock() e.mu.Lock() - if e.state == state { + if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } e.mu.Unlock() @@ -1472,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == StateListen { + if e.EndpointState() == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1731,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return err } - if e.state.connected() { + if e.EndpointState().connected() { // The endpoint is already connected. If caller hasn't been // notified yet, return success. if !e.isConnectNotified { @@ -1743,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } nicID := addr.NIC - switch e.state { + switch e.EndpointState() { case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. @@ -1850,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.isRegistered = true - e.state = StateConnecting + e.setEndpointState(StateConnecting) e.route = r.Clone() e.boundNICID = nicID e.effectiveNetProtos = netProtos @@ -1871,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = StateEstablished - e.stack.Stats().TCP.CurrentEstablished.Increment() + e.setEndpointState(StateEstablished) } if run { e.workerRunning = true e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } return tcpip.ErrConnectStarted @@ -1896,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags finQueued := false switch { - case e.state.connected(): + case e.EndpointState().connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1908,8 +2019,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) + // Move the socket to error state immediately. + // This is done redundantly because in case of + // save/restore on a Shutdown/Close() the socket + // state needs to indicate the error otherwise + // save file will show the socket in established + // state even though snd/rcv are closed. e.mu.Unlock() + // Try to send an active reset immediately if the + // work mutex is available. + if e.workMu.TryLock() { + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + e.workMu.Unlock() + } else { + e.notifyProtocolGoroutine(notifyReset) + } return nil } } @@ -1931,11 +2057,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { finQueued = true // Mark endpoint as closed. e.sndClosed = true - e.sndBufMu.Unlock() } - case e.state == StateListen: + case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1976,7 +2101,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == StateListen && !e.workerCleanup { + if e.EndpointState() == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1994,7 +2119,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { // The listen is called on an unbound socket, the socket is // automatically bound to a random free port with the local // address set to INADDR_ANY. @@ -2004,7 +2129,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Endpoint must be bound before it can transition to listen mode. - if e.state != StateBound { + if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } @@ -2015,24 +2140,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } e.isRegistered = true - e.state = StateListen + e.setEndpointState(StateListen) + if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } e.workerRunning = true - go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) - return nil } // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.mu.Lock() e.waiterQueue = waiterQueue e.workerRunning = true - go e.protocolMainLoop(false) // S/R-SAFE: drained on save. + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone } // Accept returns a new endpoint if a peer has established a connection @@ -2042,7 +2170,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != StateListen { + if e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -2069,7 +2197,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrAlreadyBound } @@ -2131,7 +2259,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) return nil } @@ -2153,7 +2281,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if !e.state.connected() { + if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -2164,45 +2292,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { - s := newSegment(r, id, pkt) - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - s.decRef() - return - } - - if !s.csumValid { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() - e.stats.SegmentsReceived.Increment() - if (s.flags & header.TCPFlagRst) != 0 { - e.stack.Stats().TCP.ResetsReceived.Increment() - } - - e.enqueueSegment(s) + // TCP HandlePacket is not required anymore as inbound packets first + // land at the Dispatcher which then can either delivery using the + // worker go routine or directly do the invoke the tcp processing inline + // based on the state of the endpoint. } -func (e *endpoint) enqueueSegment(s *segment) { +func (e *endpoint) enqueueSegment(s *segment) bool { // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - e.newSegmentWaker.Assert() - } else { + if !e.segmentQueue.enqueue(s) { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.SegmentQueueDropped.Increment() - s.decRef() + return false } + return true } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. @@ -2319,8 +2424,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { - e.recentTS = tsVal + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) } } @@ -2330,7 +2435,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { e.sendTSOk = true - e.recentTS = synOpts.TSVal + e.setRecentTimestamp(synOpts.TSVal) } } @@ -2419,7 +2524,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Endpoint TCP Option state. s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTS + s.RecentTS = e.recentTimestamp() s.TSOffset = e.tsOffset s.SACKPermitted = e.sackPermitted s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) @@ -2526,9 +2631,7 @@ func (e *endpoint) initGSO() { // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4b8d867bc..4a46f0ec5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,6 +16,7 @@ package tcp import ( "fmt" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sync" @@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound: // TODO(b/138137272): this enumeration duplicates // EndpointState.connected. remove it. @@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() { fallthrough case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } - fallthrough case StateError, StateClose: - for (e.state == StateError || e.state == StateClose) && e.workerRunning { + for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() } if e.workerRunning { - panic("endpoint still has worker running in closed or error state") + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) } default: - panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) } if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { + if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { // saveState is invoked by stateify. func (e *endpoint) saveState() EndpointState { - return e.state + return e.EndpointState() } // Endpoint loading must be done in the following ordering by their state, to @@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. - if state.connected() { + // For restore purposes we treat TimeWait like a connected endpoint. + if state.connected() || state == StateTimeWait { connectedLoading.Add(1) } switch state { @@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) { case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } - e.state = state + // Directly update the state here rather than using e.setEndpointState + // as the endpoint is still being loaded and the stack reference to increment + // metrics is not yet initialized. + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - // Freeze segment queue before registering to prevent any segments - // from being delivered while it is being restored. e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. @@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) { e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() state := e.origEndpointState - switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = StateClose + e.setEndpointState(StateClose) tcpip.AsyncLoading.Done() }() } @@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } + } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 9a8f64aa6..958c06fa7 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "runtime" "strings" "time" @@ -104,6 +105,7 @@ type protocol struct { moderateReceiveBuffer bool tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration + dispatcher *dispatcher } // Number returns the tcp protocol number. @@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } +// QueuePacket queues packets targeted at an endpoint after hashing the packet +// to a specific processing queue. Each queue is serviced by its own processor +// goroutine which is responsible for dequeuing and doing full TCP dispatch of +// the packet. +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + p.dispatcher.queuePacket(r, ep, id, pkt) +} + // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. // @@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 05c8488f8..958f03ac1 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -169,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateEstablished: - r.ep.state = StateCloseWait + r.ep.setEndpointState(StateCloseWait) case StateFinWait1: if s.flagIsSet(header.TCPFlagAck) { // FIN-ACK, transition to TIME-WAIT. - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } else { // Simultaneous close, expecting a final ACK. - r.ep.state = StateClosing + r.ep.setEndpointState(StateClosing) } case StateFinWait2: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } r.ep.mu.Unlock() @@ -205,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateFinWait1: - r.ep.state = StateFinWait2 + r.ep.setEndpointState(StateFinWait2) // Notify protocol goroutine that we have received an // ACK to our FIN so that it can start the FIN_WAIT2 // timer to abort connection if the other side does // not close within 2MSL. r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) case StateLastAck: r.ep.transitionToStateCloseLocked() } @@ -267,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo switch state { case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { - s.decRef() // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -284,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - s.decRef() return true, tcpip.ErrConnectionAborted } if state == StateFinWait1 { @@ -314,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - s.decRef() return true, tcpip.ErrConnectionAborted } } @@ -336,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { r.ep.mu.RLock() - state := r.ep.state + state := r.ep.EndpointState() closed := r.ep.closed r.ep.mu.RUnlock() diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index fdff7ed81..b74b61e7d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -705,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) - // Transition to FIN-WAIT1 state since we're initiating an active close. - s.ep.mu.Lock() - switch s.ep.state { + // Update the state to reflect that we have now + // queued a FIN. + switch s.ep.EndpointState() { case StateCloseWait: - // We've already received a FIN and are now sending our own. The - // sender is now awaiting a final ACK for this FIN. - s.ep.state = StateLastAck + s.ep.setEndpointState(StateLastAck) default: - s.ep.state = StateFinWait1 + s.ep.setEndpointState(StateFinWait1) } - s.ep.mu.Unlock() + } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6edfa8dce..a9dfbe857 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SeqNum(uint32(c.IRS+1)), checker.AckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence number + // of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(0), checker.TCPFlags(header.TCPFlagRst), @@ -1500,6 +1502,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence + // number of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. @@ -5441,6 +5446,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { rawEP.SendPacketWithTS(b[start:start+mss], tsVal) packetsSent++ } + // Resume the worker so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -5508,7 +5514,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { stk := c.Stack() // Set lower limits for auto-tuning tests. This is required because the // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. + // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { @@ -5563,6 +5569,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalSent += mss packetsSent++ } + // Resume it so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 5d114d460..2f9821555 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -533,7 +533,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // Sleep for a little over the linger timeout to reduce flakiness in // save/restore tests. - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); ds.reset(); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 6b99c021d..33a5ac66c 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -814,6 +814,20 @@ TEST_P(TcpSocketTest, FullBuffer) { t_ = -1; } +TEST_P(TcpSocketTest, PollAfterShutdown) { + ScopedThread client_thread([this]() { + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); + }); + + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); +} + TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { // Initialize address to the loopback one. sockaddr_storage addr = -- cgit v1.2.3 From f874723e64bd8a2e747bb336e0b6b8f0da1f044a Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 15 Jan 2020 11:17:25 -0800 Subject: Bump SO_SNDBUF for fdbased endpoint used by runsc. Updates #231 PiperOrigin-RevId: 289897881 --- benchmarks/tcp/tcp_proxy.go | 14 +++++++------- runsc/sandbox/network.go | 23 ++++++++++++++--------- 2 files changed, 21 insertions(+), 16 deletions(-) (limited to 'benchmarks') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index dc96add66..72ada5700 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -84,8 +84,8 @@ func (netImpl) printStats() { } const ( - nicID = 1 // Fixed. - rcvBufSize = 4 << 20 // 1MB. + nicID = 1 // Fixed. + bufSize = 4 << 20 // 4MB. ) type netstackImpl struct { @@ -125,13 +125,13 @@ func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { } // RAW Sockets by default have a very small SO_RCVBUF of 256KB, - // up it to at least 1MB to reduce packet drops. - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, rcvBufSize); err != nil { - return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + // up it to at least 4MB to reduce packet drops. + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err) } - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, rcvBufSize); err != nil { - return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err) } if !*swgso && *gso != 0 { diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index be8b72b3e..ff48f5646 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -321,16 +321,21 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( } } - // Use SO_RCVBUFFORCE because on linux the receive buffer for an - // AF_PACKET socket is capped by "net.core.rmem_max". rmem_max - // defaults to a unusually low value of 208KB. This is too low - // for gVisor to be able to receive packets at high throughputs - // without incurring packet drops. - const rcvBufSize = 4 << 20 // 4MB. - - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, rcvBufSize); err != nil { - return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", rcvBufSize, err) + // Use SO_RCVBUFFORCE/SO_SNDBUFFORCE because on linux the receive/send buffer + // for an AF_PACKET socket is capped by "net.core.rmem_max/wmem_max". + // wmem_max/rmem_max default to a unusually low value of 208KB. This is too low + // for gVisor to be able to receive packets at high throughputs without + // incurring packet drops. + const bufSize = 4 << 20 // 4MB. + + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil { + return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", bufSize, err) } + + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil { + return nil, fmt.Errorf("failed to increase socket snd buffer to %d: %v", bufSize, err) + } + return &socketEntry{deviceFile, gsoMaxSize}, nil } -- cgit v1.2.3 From 7b7ce29af326ccd247ee5225e9b5b55a9d0330ce Mon Sep 17 00:00:00 2001 From: Zach Koopmans Date: Wed, 15 Jan 2020 14:24:55 -0800 Subject: Update commandline and get local runs working. PiperOrigin-RevId: 289937063 --- benchmarks/README.md | 126 ++++++++++----------- benchmarks/harness/__init__.py | 2 +- benchmarks/harness/machine.py | 8 +- .../harness/machine_producers/machine_producer.py | 21 ++++ benchmarks/runner/BUILD | 10 ++ benchmarks/runner/__init__.py | 59 +++------- benchmarks/runner/commands.py | 84 ++++++++++++++ benchmarks/runner/runner_test.py | 2 +- benchmarks/suites/http.py | 2 +- benchmarks/workloads/BUILD | 40 +++---- benchmarks/workloads/ab/BUILD | 5 +- benchmarks/workloads/absl/BUILD | 5 +- benchmarks/workloads/curl/BUILD | 6 +- benchmarks/workloads/ffmpeg/BUILD | 6 +- benchmarks/workloads/fio/BUILD | 5 +- benchmarks/workloads/httpd/BUILD | 6 +- benchmarks/workloads/iperf/BUILD | 5 +- benchmarks/workloads/netcat/BUILD | 6 +- benchmarks/workloads/nginx/BUILD | 6 +- benchmarks/workloads/node/BUILD | 6 +- benchmarks/workloads/node_template/BUILD | 6 +- benchmarks/workloads/redis/BUILD | 6 +- benchmarks/workloads/redisbenchmark/BUILD | 5 +- benchmarks/workloads/ruby/BUILD | 13 +++ benchmarks/workloads/ruby_template/BUILD | 7 +- benchmarks/workloads/sleep/BUILD | 6 +- benchmarks/workloads/sysbench/BUILD | 5 +- benchmarks/workloads/syscall/BUILD | 5 +- benchmarks/workloads/tensorflow/BUILD | 6 +- benchmarks/workloads/true/BUILD | 7 +- 30 files changed, 303 insertions(+), 173 deletions(-) create mode 100644 benchmarks/runner/commands.py (limited to 'benchmarks') diff --git a/benchmarks/README.md b/benchmarks/README.md index ad44cd6ac..ff21614c5 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -6,66 +6,55 @@ These scripts are tools for collecting performance data for Docker-based tests. The scripts assume the following: -* You have a local machine with bazel installed. -* You have some machine(s) with docker installed. These machines will be - refered to as the "Environment". -* Environment machines have the runtime(s) under test installed, such that you - can run docker with a command like: `docker run --runtime=$RUNTIME - your/image`. -* You are able to login to machines in the environment with the local machine - via ssh and the user for ssh can run docker commands without using `sudo`. +* There are two sets of machines: one where the scripts will be run + (controller) and one or more machines on which docker containers will be run + (environment). +* The controller machine must have bazel installed along with this source + code. You should be able to run a command like `bazel run :benchmarks -- + --list` +* Environment machines must have docker and the required runtimes installed. + More specifically, you should be able to run a command like: `docker run + --runtime=$RUNTIME your/image`. +* The controller has ssh private key which can be used to login to environment + machines and run docker commands without using `sudo`. This is not required + if running locally via the `run-local` command. * The docker daemon on each of your environment machines is listening on `unix:///var/run/docker.sock` (docker's default). For configuring the environment manually, consult the [dockerd documentation][dockerd]. -## Environment - -All benchmarks require a user defined yaml file describe the environment. These -files are of the form: - -```yaml -machine1: local -machine2: - hostname: 100.100.100.100 - username: username - key_path: ~/private_keyfile - key_password: passphrase -machine3: - hostname: 100.100.100.101 - username: username - key_path: ~/private_keyfile - key_password: passphrase -``` +## Running benchmarks -The yaml file defines an environment with three machines named `machine1`, -`machine2` and `machine3`. `machine1` is the local machine, `machine2` and -`machine3` are remote machines. Both `machine2` and `machine3` should be -reachable by `ssh`. For example, the command `ssh -i ~/private_keyfile -username@100.100.100.100` (using the passphrase `passphrase`) should connect to -`machine2`. +Run the following from the benchmarks directory: -The above is an example only. Machines should be uniform, since they are treated -as such by the tests. Machines must also be accessible to each other via their -default routes. Furthermore, some benchmarks will meaningless if running on the -local machine, such as density. +```bash +bazel run :benchmarks -- run-local startup -For remote machines, `hostname`, `key_path`, and `username` are required and -others are optional. In addition key files must be generated -[using the instrcutions below](#generating-ssh-keys). +... +method,metric,result +startup.empty,startup_time_ms,652.5772 +startup.node,startup_time_ms,1654.4042000000002 +startup.ruby,startup_time_ms,1429.835 +``` -The above yaml file can be checked for correctness with the `validate` command -in the top level perf.py script: +The above command ran the startup benchmark locally, which consists of three +benchmarks (empty, node, and ruby). Benchmark tools ran it on the default +runtime, runc. Running on another installed runtime, like say runsc, is as +simple as: -`bazel run :benchmarks -- validate $PWD/examples/localhost.yaml` +```bash +bazel run :benchmakrs -- run-local startup --runtime=runsc +``` -## Running benchmarks +There is help: ``bash bash bazel run :benchmarks -- --help bazel +run :benchmarks -- run-local --help` `` To list available benchmarks, use the `list` commmand: ```bash bazel run :benchmarks -- list +ls ... Benchmark: sysbench.cpu @@ -75,24 +64,44 @@ Metrics: events_per_second :param max_prime: The maximum prime number to search. ``` -To run benchmarks, use the `run` command. For example, to run the sysbench -benchmark above: +You can choose benchmarks by name or regex like: ```bash -bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml sysbench.cpu +bazel run :benchmarks -- run-local startup.node +... +metric,result +startup_time_ms,1671.7178000000001 + +``` + +or + +```bash +bazel run :benchmarks -- run-local s +... +method,metric,result +startup.empty,startup_time_ms,1792.8292 +startup.node,startup_time_ms,3113.5274 +startup.ruby,startup_time_ms,3025.2424 +sysbench.cpu,cpu_events_per_second,12661.47 +sysbench.memory,memory_ops_per_second,7228268.44 +sysbench.mutex,mutex_time,17.4835 +sysbench.mutex,mutex_latency,3496.7 +sysbench.mutex,mutex_deviation,0.04 +syscall.syscall,syscall_time_ns,2065.0 ``` You can run parameterized benchmarks, for example to run with different runtimes: ```bash -bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --runtime=runc --runtime=runsc sysbench.cpu +bazel run :benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu ``` Or with different parameters: ```bash -bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --max_prime=10 --max_prime=100 sysbench.cpu +bazel run :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu ``` ## Writing benchmarks @@ -121,7 +130,7 @@ The harness requires workloads to run. These are all available in the In general, a workload consists of a Dockerfile to build it (while these are not hermetic, in general they should be as fixed and isolated as possible), some -parses for output if required, parser tests and sample data. Provided the test +parsers for output if required, parser tests and sample data. Provided the test is named after the workload package and contains a function named `sample`, this variable will be used to automatically mock workload output when the `--mock` flag is provided to the main tool. @@ -149,24 +158,5 @@ To write a new benchmark, open a module in the `suites` directory and use the above signature. You should add a descriptive doc string to describe what your benchmark is and any test centric arguments. -## Generating SSH Keys - -The scripts only support RSA Keys, and ssh library used in paramiko. Paramiko -only supports RSA keys that look like the following (PEM format): - -```bash -$ cat /path/to/ssh/key - ------BEGIN RSA PRIVATE KEY----- -...private key text... ------END RSA PRIVATE KEY----- - -``` - -To generate ssh keys in PEM format, use the [`-t rsa -m PEM -b 4096`][RSA-keys]. -option. - [dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/ [docker-py]: https://docker-py.readthedocs.io/en/stable/ -[paramiko]: http://docs.paramiko.org/en/2.4/api/client.html -[RSA-keys]: https://serverfault.com/questions/939909/ssh-keygen-does-not-create-rsa-private-key diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py index a7f34da9e..7b96d1666 100644 --- a/benchmarks/harness/__init__.py +++ b/benchmarks/harness/__init__.py @@ -18,7 +18,7 @@ import os # LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a # format string that accepts a single string parameter. LOCAL_WORKLOADS_PATH = os.path.join( - os.path.dirname(__file__), "../workloads/{}") + os.path.dirname(__file__), "../workloads/{}/tar.tar") # REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the # remote host. This is a format string that accepts a single string parameter. diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py index 66b719b63..af037dbcc 100644 --- a/benchmarks/harness/machine.py +++ b/benchmarks/harness/machine.py @@ -160,15 +160,17 @@ class LocalMachine(Machine): stdout, stderr = process.communicate() return stdout.decode("utf-8"), stderr.decode("utf-8") - def read(self, path: str) -> str: + def read(self, path: str) -> bytes: # Read the exact path locally. return open(path, "r").read() def pull(self, workload: str) -> str: # Run the docker build command locally. logging.info("Building %s@%s locally...", workload, self._name) - self.run("docker build --tag={} {}".format( - workload, harness.LOCAL_WORKLOADS_PATH.format(workload))) + with open(harness.LOCAL_WORKLOADS_PATH.format(workload), + "rb") as dockerfile: + self._docker_client.images.build( + fileobj=dockerfile, tag=workload, custom_context=True) return workload # Workload is the tag. def container(self, image: str, **kwargs) -> container.Container: diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py index 124ee14cc..f5591c026 100644 --- a/benchmarks/harness/machine_producers/machine_producer.py +++ b/benchmarks/harness/machine_producers/machine_producer.py @@ -13,6 +13,7 @@ # limitations under the License. """Abstract types.""" +import threading from typing import List from benchmarks.harness import machine @@ -28,3 +29,23 @@ class MachineProducer: def release_machines(self, machine_list: List[machine.Machine]): """Releases the given set of machines.""" raise NotImplementedError + + +class LocalMachineProducer(MachineProducer): + """Produces Local Machines.""" + + def __init__(self, limit: int): + self.limit_sem = threading.Semaphore(value=limit) + + def get_machines(self, num_machines: int) -> List[machine.Machine]: + """Returns the request number of MockMachines.""" + + self.limit_sem.acquire() + return [machine.LocalMachine("local") for _ in range(num_machines)] + + def release_machines(self, machine_list: List[machine.MockMachine]): + """No-op.""" + if not machine_list: + raise ValueError("Cannot release an empty list!") + self.limit_sem.release() + machine_list.clear() diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD index de24824cc..e1b2ea550 100644 --- a/benchmarks/runner/BUILD +++ b/benchmarks/runner/BUILD @@ -10,7 +10,9 @@ py_library( ], visibility = ["//benchmarks:__pkg__"], deps = [ + ":commands", "//benchmarks/harness:benchmark_driver", + "//benchmarks/harness/machine_producers:machine_producer", "//benchmarks/harness/machine_producers:mock_producer", "//benchmarks/harness/machine_producers:yaml_producer", "//benchmarks/suites", @@ -30,6 +32,14 @@ py_library( ], ) +py_library( + name = "commands", + srcs = ["commands.py"], + deps = [ + requirement("click", True), + ], +) + py_test( name = "runner_test", srcs = ["runner_test.py"], diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py index 9bf9cfd65..6f56704d8 100644 --- a/benchmarks/runner/__init__.py +++ b/benchmarks/runner/__init__.py @@ -28,8 +28,10 @@ import click from benchmarks import suites from benchmarks.harness import benchmark_driver +from benchmarks.harness.machine_producers import machine_producer from benchmarks.harness.machine_producers import mock_producer from benchmarks.harness.machine_producers import yaml_producer +from benchmarks.runner import commands @click.group() @@ -100,30 +102,22 @@ def list_all(method): print("\n") -# pylint: disable=too-many-arguments -# pylint: disable=too-many-branches -# pylint: disable=too-many-locals -@runner.command( - context_settings=dict(ignore_unknown_options=True, allow_extra_args=True)) +@runner.command("run-local", commands.LocalCommand) @click.pass_context -@click.argument("method") -@click.option("--mock/--no-mock", default=False, help="Mock the machines.") -@click.option("--env", default=None, help="Specify a yaml file with machines.") -@click.option( - "--runtime", default=["runc"], help="The runtime to use.", multiple=True) -@click.option("--metric", help="The metric to extract.", multiple=True) -@click.option( - "--runs", default=1, help="The number of times to run each benchmark.") -@click.option( - "--stat", - default="median", - help="How to aggregate the data from all runs." - "\nmedian - returns the median of all runs (default)" - "\nall - returns all results comma separated" - "\nmeanstd - returns result as mean,std") -# pylint: disable=too-many-statements -def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str], - metric: List[str], stat: str, **kwargs): +def run_local(ctx, limit: float, **kwargs): + """Runs benchmarks locally.""" + run(ctx, machine_producer.LocalMachineProducer(limit=limit), **kwargs) + + +@runner.command("run-mock", commands.RunCommand) +@click.pass_context +def run_mock(ctx, **kwargs): + """Runs benchmarks on Mock machines. Used for testing.""" + run(ctx, mock_producer.MockMachineProducer(), **kwargs) + + +def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int, + runtime: List[str], metric: List[str], stat: str, **kwargs): """Runs arbitrary benchmarks. All unknown command line flags are passed through to the underlying benchmark @@ -139,16 +133,13 @@ def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str], All benchmarks are run in parallel where possible, but have exclusive ownership over the individual machines. - Exactly one of the --mock and --env flag must be specified. - Every benchmark method will be run the times indicated by --runs. Args: ctx: Click context. + producer: A Machine Producer from which to get Machines. method: A regular expression for methods to be run. runs: Number of runs. - env: Environment to use. - mock: If true, use mocked environment (supercedes env). runtime: A list of runtimes to test. metric: A list of metrics to extract. stat: The class of statistics to extract. @@ -218,20 +209,6 @@ def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str], sys.exit(1) fold("method", list(methods.keys()), allow_flatten=True) - # Construct the environment. - if mock and env: - # You can't provide both. - logging.error("both --mock and --env are set: which one is it?") - sys.exit(1) - elif mock: - producer = mock_producer.MockMachineProducer() - elif env: - producer = yaml_producer.YamlMachineProducer(env) - else: - # You must provide one of mock or env. - logging.error("no enviroment provided: use --mock or --env.") - sys.exit(1) - # Spin up the drivers. # # We ensure that metric is the last entry, because we have special behavior. diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py new file mode 100644 index 000000000..4973843b9 --- /dev/null +++ b/benchmarks/runner/commands.py @@ -0,0 +1,84 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module with the guts of `click` commands. + +Overrides of the click.core.Command. This is done so flags are inherited between +similar commands (the run command). The classes below are meant to be used in +click templates like so. + +@runner.command("run-mock", RunCommand) +def run_mock(**kwargs): + # mock implementation + +""" +import click + + +class RunCommand(click.core.Command): + """Base Run Command with flags. + + Attributes: + method: regex of which suite to choose (e.g. sysbench would run + sysbench.cpu, sysbench.memory, and sysbench.mutex) See list command for + details. + metric: metric(s) to extract. See list command for details. + runtime: the runtime(s) on which to run. + runs: the number of runs to do of each method. + stat: how to compile results in the case of multiple run (e.g. median). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + method = click.core.Argument(("method",)) + + metric = click.core.Option(("--metric",), + help="The metric to extract.", + multiple=True) + + runtime = click.core.Option(("--runtime",), + default=["runc"], + help="The runtime to use.", + multiple=True) + runs = click.core.Option(("--runs",), + default=1, + help="The number of times to run each benchmark.") + stat = click.core.Option( + ("--stat",), + default="median", + help="How to aggregate the data from all runs." + "\nmedian - returns the median of all runs (default)" + "\nall - returns all results comma separated" + "\nmeanstd - returns result as mean,std") + self.params.extend([method, runtime, runs, stat, metric]) + self.ignore_unknown_options = True + self.allow_extra_args = True + + +class LocalCommand(RunCommand): + """LocalCommand inherits all flags from RunCommand. + + Attributes: + limit: limits the number of machines on which to run benchmarks. This limits + for local how many benchmarks may run at a time. e.g. "startup" requires + one machine -- passing two machines would limit two startup jobs at a + time. Default is infinity. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.params.append( + click.core.Option( + ("--limit",), + default=1, + help="Limit of number of benchmarks that can run at a given time.")) diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py index 5719c2838..7818d631a 100644 --- a/benchmarks/runner/runner_test.py +++ b/benchmarks/runner/runner_test.py @@ -49,7 +49,7 @@ def test_list(): def test_run(): cli_runner = testing.CliRunner() - result = cli_runner.invoke(runner.runner, ["run", "--mock", "."]) + result = cli_runner.invoke(runner.runner, ["run-mock", "."]) print(result.output) assert result.exit_code == 0 diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py index ea9024e43..6efea938c 100644 --- a/benchmarks/suites/http.py +++ b/benchmarks/suites/http.py @@ -92,7 +92,7 @@ def http_app(server: machine.Machine, redis = server.pull("redis") image = server.pull(workload) redis_port = 6379 - redis_name = "redis_server" + redis_name = "{workload}_redis_server".format(workload=workload) with server.container(redis, name=redis_name).detach(): server.container(server_netcat, links={redis_name: redis_name})\ diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD index 643806105..ccb86af5b 100644 --- a/benchmarks/workloads/BUILD +++ b/benchmarks/workloads/BUILD @@ -11,25 +11,25 @@ py_library( filegroup( name = "files", srcs = [ - "//benchmarks/workloads/ab:files", - "//benchmarks/workloads/absl:files", - "//benchmarks/workloads/curl:files", - "//benchmarks/workloads/ffmpeg:files", - "//benchmarks/workloads/fio:files", - "//benchmarks/workloads/httpd:files", - "//benchmarks/workloads/iperf:files", - "//benchmarks/workloads/netcat:files", - "//benchmarks/workloads/nginx:files", - "//benchmarks/workloads/node:files", - "//benchmarks/workloads/node_template:files", - "//benchmarks/workloads/redis:files", - "//benchmarks/workloads/redisbenchmark:files", - "//benchmarks/workloads/ruby:files", - "//benchmarks/workloads/ruby_template:files", - "//benchmarks/workloads/sleep:files", - "//benchmarks/workloads/sysbench:files", - "//benchmarks/workloads/syscall:files", - "//benchmarks/workloads/tensorflow:files", - "//benchmarks/workloads/true:files", + "//benchmarks/workloads/ab:tar", + "//benchmarks/workloads/absl:tar", + "//benchmarks/workloads/curl:tar", + "//benchmarks/workloads/ffmpeg:tar", + "//benchmarks/workloads/fio:tar", + "//benchmarks/workloads/httpd:tar", + "//benchmarks/workloads/iperf:tar", + "//benchmarks/workloads/netcat:tar", + "//benchmarks/workloads/nginx:tar", + "//benchmarks/workloads/node:tar", + "//benchmarks/workloads/node_template:tar", + "//benchmarks/workloads/redis:tar", + "//benchmarks/workloads/redisbenchmark:tar", + "//benchmarks/workloads/ruby:tar", + "//benchmarks/workloads/ruby_template:tar", + "//benchmarks/workloads/sleep:tar", + "//benchmarks/workloads/sysbench:tar", + "//benchmarks/workloads/syscall:tar", + "//benchmarks/workloads/tensorflow:tar", + "//benchmarks/workloads/true:tar", ], ) diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD index e99a8d674..4fc0ab735 100644 --- a/benchmarks/workloads/ab/BUILD +++ b/benchmarks/workloads/ab/BUILD @@ -1,4 +1,5 @@ load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("@rules_pkg//:pkg.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -27,8 +28,8 @@ py_test( ], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD index bb499620e..61e010096 100644 --- a/benchmarks/workloads/absl/BUILD +++ b/benchmarks/workloads/absl/BUILD @@ -1,4 +1,5 @@ load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("@rules_pkg//:pkg.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -27,8 +28,8 @@ py_test( ], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD index 83f3c71a0..eb0fb6165 100644 --- a/benchmarks/workloads/curl/BUILD +++ b/benchmarks/workloads/curl/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD index c1f2afc40..be472dfb2 100644 --- a/benchmarks/workloads/ffmpeg/BUILD +++ b/benchmarks/workloads/ffmpeg/BUILD @@ -1,3 +1,5 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], @@ -8,8 +10,8 @@ py_library( srcs = ["__init__.py"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD index 7fc96cfa5..de257adad 100644 --- a/benchmarks/workloads/fio/BUILD +++ b/benchmarks/workloads/fio/BUILD @@ -1,4 +1,5 @@ load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("@rules_pkg//:pkg.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -27,8 +28,8 @@ py_test( ], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD index 83f3c71a0..eb0fb6165 100644 --- a/benchmarks/workloads/httpd/BUILD +++ b/benchmarks/workloads/httpd/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD index fe0acbfce..8832a996c 100644 --- a/benchmarks/workloads/iperf/BUILD +++ b/benchmarks/workloads/iperf/BUILD @@ -1,4 +1,5 @@ load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("@rules_pkg//:pkg.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -27,8 +28,8 @@ py_test( ], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD index 83f3c71a0..eb0fb6165 100644 --- a/benchmarks/workloads/netcat/BUILD +++ b/benchmarks/workloads/netcat/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD index 83f3c71a0..eb0fb6165 100644 --- a/benchmarks/workloads/nginx/BUILD +++ b/benchmarks/workloads/nginx/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD index 59460d02f..71cd9f519 100644 --- a/benchmarks/workloads/node/BUILD +++ b/benchmarks/workloads/node/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", "index.js", diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD index ae7f121d3..ca996f068 100644 --- a/benchmarks/workloads/node_template/BUILD +++ b/benchmarks/workloads/node_template/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", "index.hbs", diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD index 83f3c71a0..eb0fb6165 100644 --- a/benchmarks/workloads/redis/BUILD +++ b/benchmarks/workloads/redis/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD index d40e75a3a..f5994a815 100644 --- a/benchmarks/workloads/redisbenchmark/BUILD +++ b/benchmarks/workloads/redisbenchmark/BUILD @@ -1,4 +1,5 @@ load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("@rules_pkg//:pkg.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -27,8 +28,8 @@ py_test( ], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD index 9846c7e70..e37d77804 100644 --- a/benchmarks/workloads/ruby/BUILD +++ b/benchmarks/workloads/ruby/BUILD @@ -1,3 +1,5 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], @@ -13,3 +15,14 @@ filegroup( "index.rb", ], ) + +pkg_tar( + name = "tar", + srcs = [ + "Dockerfile", + "Gemfile", + "Gemfile.lock", + "config.ru", + "index.rb", + ], +) diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD index 2b99892af..27f7c0c46 100644 --- a/benchmarks/workloads/ruby_template/BUILD +++ b/benchmarks/workloads/ruby_template/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", "Gemfile", @@ -13,4 +15,5 @@ filegroup( "index.erb", "main.rb", ], + strip_prefix = "third_party/gvisor/benchmarks/workloads/ruby_template", ) diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD index 83f3c71a0..eb0fb6165 100644 --- a/benchmarks/workloads/sleep/BUILD +++ b/benchmarks/workloads/sleep/BUILD @@ -1,10 +1,12 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD index 35f4d460b..fd2f8f03d 100644 --- a/benchmarks/workloads/sysbench/BUILD +++ b/benchmarks/workloads/sysbench/BUILD @@ -1,4 +1,5 @@ load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("@rules_pkg//:pkg.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -27,8 +28,8 @@ py_test( ], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD index e1ff3059b..5100cbb21 100644 --- a/benchmarks/workloads/syscall/BUILD +++ b/benchmarks/workloads/syscall/BUILD @@ -1,4 +1,5 @@ load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("@rules_pkg//:pkg.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -27,8 +28,8 @@ py_test( ], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", "syscall.c", diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD index 17f1f8ebb..026c3b316 100644 --- a/benchmarks/workloads/tensorflow/BUILD +++ b/benchmarks/workloads/tensorflow/BUILD @@ -1,3 +1,5 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], @@ -8,8 +10,8 @@ py_library( srcs = ["__init__.py"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD index 83f3c71a0..221c4b9a7 100644 --- a/benchmarks/workloads/true/BUILD +++ b/benchmarks/workloads/true/BUILD @@ -1,11 +1,14 @@ +load("@rules_pkg//:pkg.bzl", "pkg_tar") + package( default_visibility = ["//benchmarks:__subpackages__"], licenses = ["notice"], ) -filegroup( - name = "files", +pkg_tar( + name = "tar", srcs = [ "Dockerfile", ], + extension = "tar", ) -- cgit v1.2.3 From 94be30a18dc7c75dc70716ce1ede74a7fb1352fb Mon Sep 17 00:00:00 2001 From: Zach Koopmans Date: Thu, 16 Jan 2020 13:00:58 -0800 Subject: Add run-gcp command. Add command to run benchmarks on GCP backed machines using the gcloud producer. Run with: `bazel run :benchmarks -- run-gcp [BENCHMARK_NAME]` Tested with the startup benchmark. PiperOrigin-RevId: 290126444 --- benchmarks/harness/__init__.py | 7 +++ benchmarks/harness/machine.py | 3 + .../harness/machine_producers/gcloud_producer.py | 70 ++++++++++++++-------- benchmarks/harness/ssh_connection.py | 9 +-- benchmarks/runner/__init__.py | 60 +++++++++++++++++++ benchmarks/runner/commands.py | 51 ++++++++++++++++ 6 files changed, 168 insertions(+), 32 deletions(-) (limited to 'benchmarks') diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py index 7b96d1666..61fd25f73 100644 --- a/benchmarks/harness/__init__.py +++ b/benchmarks/harness/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. """Core benchmark utilities.""" +import getpass import os # LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a @@ -23,3 +24,9 @@ LOCAL_WORKLOADS_PATH = os.path.join( # REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the # remote host. This is a format string that accepts a single string parameter. REMOTE_WORKLOADS_PATH = "workloads/{}" + +# DEFAULT_USER is the default user running this script. +DEFAULT_USER = getpass.getuser() + +# DEFAULT_USER_HOME is the home directory of the user running the script. +DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else "" diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py index af037dbcc..2df4c9e31 100644 --- a/benchmarks/harness/machine.py +++ b/benchmarks/harness/machine.py @@ -214,6 +214,9 @@ class RemoteMachine(Machine): # Push to the remote machine and build. logging.info("Building %s@%s remotely...", workload, self._name) remote_path = self._ssh_connection.send_workload(workload) + # Workloads are all tarballs. + self.run("tar -xvf {remote_path}/tar.tar -C {remote_path}".format( + remote_path=remote_path)) self.run("docker build --tag={} {}".format(workload, remote_path)) return workload # Workload is the tag. diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py index 4693dd8a2..e0b77d52b 100644 --- a/benchmarks/harness/machine_producers/gcloud_producer.py +++ b/benchmarks/harness/machine_producers/gcloud_producer.py @@ -29,7 +29,6 @@ collisions with user instances shouldn't happen. producer.release_machines(NUM_MACHINES) """ import datetime -import getpass import json import subprocess import threading @@ -40,8 +39,6 @@ from benchmarks.harness import machine from benchmarks.harness.machine_producers import gcloud_mock_recorder from benchmarks.harness.machine_producers import machine_producer -DEFAULT_USER = getpass.getuser() - class GCloudProducer(machine_producer.MachineProducer): """Implementation of MachineProducer backed by GCP. @@ -50,9 +47,10 @@ class GCloudProducer(machine_producer.MachineProducer): Attributes: project: The GCP project name under which to create the machines. - ssh_key_path: path to a valid ssh key. See README on vaild ssh keys. + ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys. image: image name as a string. image_project: image project as a string. + machine_type: type of GCP to create. e.g. n1-standard-4 zone: string to a valid GCP zone. ssh_user: string of user name for ssh_key ssh_password: string of password for ssh key @@ -63,18 +61,22 @@ class GCloudProducer(machine_producer.MachineProducer): def __init__(self, project: str, - ssh_key_path: str, + ssh_key_file: str, image: str, image_project: str, + machine_type: str, zone: str, ssh_user: str, + ssh_password: str, mock: gcloud_mock_recorder.MockPrinter = None): self.project = project - self.ssh_key_path = ssh_key_path + self.ssh_key_file = ssh_key_file self.image = image self.image_project = image_project + self.machine_type = machine_type self.zone = zone - self.ssh_user = ssh_user if ssh_user else DEFAULT_USER + self.ssh_user = ssh_user + self.ssh_password = ssh_password self.mock = mock self.condition = threading.Condition() @@ -86,20 +88,19 @@ class GCloudProducer(machine_producer.MachineProducer): with self.condition: names = self._get_unique_names(num_machines) self._build_instances(names) - instances = self._start_command(names) - self._add_ssh_key_to_instances(names) - return self._machines_from_instances(instances) + instances = self._start_command(names) + self._add_ssh_key_to_instances(names) + return self._machines_from_instances(instances) def release_machines(self, machine_list: List[machine.Machine]): """Releases the requested number of machines, deleting the instances.""" if not machine_list: return - with self.condition: - cmd = "gcloud compute instances delete --quiet".split(" ") - names = [str(m) for m in machine_list] - cmd.extend(names) - cmd.append("--zone={zone}".format(zone=self.zone)) - self._run_command(cmd) + cmd = "gcloud compute instances delete --quiet".split(" ") + names = [str(m) for m in machine_list] + cmd.extend(names) + cmd.append("--zone={zone}".format(zone=self.zone)) + self._run_command(cmd, detach=True) def _machines_from_instances( self, instances: List[Dict[str, Any]]) -> List[machine.Machine]: @@ -111,9 +112,11 @@ class GCloudProducer(machine_producer.MachineProducer): "hostname": instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"], "key_path": - self.ssh_key_path, + self.ssh_key_file, "username": - self.ssh_user + self.ssh_user, + "key_password": + self.ssh_password } machines.append(machine.RemoteMachine(name=name, **kwargs)) return machines @@ -148,12 +151,15 @@ class GCloudProducer(machine_producer.MachineProducer): "_build_instances cannot create instances without names.") cmd = "gcloud compute instances create".split(" ") cmd.extend(names) - cmd.extend("--preemptible --image={image} --zone={zone}".format( - image=self.image, zone=self.zone).split(" ")) + cmd.extend( + "--preemptible --image={image} --zone={zone} --machine-type={machine_type}" + .format( + image=self.image, zone=self.zone, + machine_type=self.machine_type).split(" ")) if self.image_project: cmd.append("--image-project={project}".format(project=self.image_project)) - res = self._run_command(cmd) - return json.loads(res.stdout) + res = self._run_command(cmd) + return json.loads(res.stdout) def _start_command(self, names): """Starts instances using gcloud command. @@ -184,7 +190,7 @@ class GCloudProducer(machine_producer.MachineProducer): Args: names: list of machine names to which to add the ssh-key - self.ssh_key_path. + self.ssh_key_file. Raises: subprocess.CalledProcessError: when underlying subprocess call returns an @@ -193,7 +199,7 @@ class GCloudProducer(machine_producer.MachineProducer): """ for name in names: cmd = "gcloud compute ssh {name}".format(name=name).split(" ") - cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_path)) + cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file)) cmd.append("--zone={zone}".format(zone=self.zone)) cmd.append("--command=uname") timeout = datetime.timedelta(seconds=5 * 60) @@ -221,7 +227,9 @@ class GCloudProducer(machine_producer.MachineProducer): res = self._run_command(cmd) return json.loads(res.stdout) - def _run_command(self, cmd: List[str]) -> subprocess.CompletedProcess: + def _run_command(self, + cmd: List[str], + detach: bool = False) -> [None, subprocess.CompletedProcess]: """Runs command as a subprocess. Runs command as subprocess and returns the result. @@ -230,14 +238,24 @@ class GCloudProducer(machine_producer.MachineProducer): Args: cmd: command to be run as a list of strings. + detach: if True, run the child process and don't wait for it to return. Returns: - Completed process object to be parsed by caller. + Completed process object to be parsed by caller or None if detach=True. Raises: CalledProcessError: if subprocess.run returns an error. """ cmd = cmd + ["--format=json"] + if detach: + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if self.mock: + out, _ = p.communicate() + self.mock.record( + subprocess.CompletedProcess( + returncode=p.returncode, stdout=out, args=p.args)) + return + res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if self.mock: self.mock.record(res) diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py index fcbfbcdb2..e0bf258f1 100644 --- a/benchmarks/harness/ssh_connection.py +++ b/benchmarks/harness/ssh_connection.py @@ -94,7 +94,7 @@ class SSHConnection: return stdout, stderr def send_workload(self, name: str) -> str: - """Sends a workload to the remote machine. + """Sends a workload tarball to the remote machine. Args: name: The workload name. @@ -103,9 +103,6 @@ class SSHConnection: The remote path. """ with self._client() as client: - for dirpath, _, filenames in os.walk( - harness.LOCAL_WORKLOADS_PATH.format(name)): - for filename in filenames: - send_one_file(client, os.path.join(dirpath, filename), - harness.REMOTE_WORKLOADS_PATH.format(name)) + send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name), + harness.REMOTE_WORKLOADS_PATH.format(name)) return harness.REMOTE_WORKLOADS_PATH.format(name) diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py index 6f56704d8..ba80d83d7 100644 --- a/benchmarks/runner/__init__.py +++ b/benchmarks/runner/__init__.py @@ -15,10 +15,13 @@ import copy import csv +import json import logging +import os import pkgutil import pydoc import re +import subprocess import sys import types from typing import List @@ -26,8 +29,10 @@ from typing import Tuple import click +from benchmarks import harness from benchmarks import suites from benchmarks.harness import benchmark_driver +from benchmarks.harness.machine_producers import gcloud_producer from benchmarks.harness.machine_producers import machine_producer from benchmarks.harness.machine_producers import mock_producer from benchmarks.harness.machine_producers import yaml_producer @@ -116,6 +121,61 @@ def run_mock(ctx, **kwargs): run(ctx, mock_producer.MockMachineProducer(), **kwargs) +@runner.command("run-gcp", commands.GCPCommand) +@click.pass_context +def run_gcp(ctx, project: str, ssh_key_file: str, image: str, + image_project: str, machine_type: str, zone: str, ssh_user: str, + ssh_password: str, **kwargs): + """Runs all benchmarks on GCP instances.""" + + if not ssh_user: + ssh_user = harness.DEFAULT_USER + + # Get the default project if one was not provided. + if not project: + sub = subprocess.run( + "gcloud config get-value project".split(" "), stdout=subprocess.PIPE) + if sub.returncode: + raise ValueError( + "Cannot get default project from gcloud. Is it configured>") + project = sub.stdout.decode("utf-8").strip("\n") + + if not image_project: + image_project = project + + # Check that the ssh-key exists and is readable. + if not os.access(ssh_key_file, os.R_OK): + raise ValueError( + "ssh key given `{ssh_key}` is does not exist or is not readable." + .format(ssh_key=ssh_key_file)) + + # Check that the image exists. + sub = subprocess.run( + "gcloud compute images describe {image} --project {image_project} --format=json" + .format(image=image, image_project=image_project).split(" "), + stdout=subprocess.PIPE) + if sub.returncode or "READY" not in json.loads(sub.stdout)["status"]: + raise ValueError( + "given image was not found or is not ready: {image} {image_project}." + .format(image=image, image_project=image_project)) + + # Check and set zone to default. + if not zone: + sub = subprocess.run( + "gcloud config get-value compute/zone".split(" "), + stdout=subprocess.PIPE) + if sub.returncode: + raise ValueError( + "Default zone is not set in gcloud. Set one or pass a zone with the --zone flag." + ) + zone = sub.stdout.decode("utf-8").strip("\n") + + producer = gcloud_producer.GCloudProducer(project, ssh_key_file, image, + image_project, machine_type, zone, + ssh_user, ssh_password) + run(ctx, producer, **kwargs) + + def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int, runtime: List[str], metric: List[str], stat: str, **kwargs): """Runs arbitrary benchmarks. diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py index 4973843b9..7ab12fac6 100644 --- a/benchmarks/runner/commands.py +++ b/benchmarks/runner/commands.py @@ -24,6 +24,8 @@ def run_mock(**kwargs): """ import click +from benchmarks import harness + class RunCommand(click.core.Command): """Base Run Command with flags. @@ -82,3 +84,52 @@ class LocalCommand(RunCommand): ("--limit",), default=1, help="Limit of number of benchmarks that can run at a given time.")) + + +class GCPCommand(RunCommand): + """GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method. + + Attributes: + project: GCP project + ssh_key_path: path to the ssh-key to use for the run + image: name of the image to build machines from + image_project: GCP project under which to find image + zone: a GCP zone (e.g. us-west1-b) + ssh_user: username to use for the ssh-key + ssh_password: password to use for the ssh-key + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + project = click.core.Option( + ("--project",), + help="Project to run on if not default value given by 'gcloud config get-value project'." + ) + ssh_key_path = click.core.Option( + ("--ssh-key-file",), + help="Path to a valid ssh private key to use. See README on generating a valid ssh key. Set to ~/.ssh/benchmark-tools by default.", + default=harness.DEFAULT_USER_HOME + "/.ssh/benchmark-tools") + image = click.core.Option(("--image",), + help="The image on which to build VMs.", + default="bm-tools-testing") + image_project = click.core.Option( + ("--image_project",), + help="The project under which the image to be used is listed.", + default="") + machine_type = click.core.Option(("--machine_type",), + help="Type to make all machines.", + default="n1-standard-4") + zone = click.core.Option(("--zone",), + help="The GCP zone to run on.", + default="") + ssh_user = click.core.Option(("--ssh-user",), + help="User for the ssh key.", + default=harness.DEFAULT_USER) + ssh_password = click.core.Option(("--ssh-password",), + help="Password for the ssh key.", + default="") + self.params.extend([ + project, ssh_key_path, image, image_project, machine_type, zone, + ssh_user, ssh_password + ]) -- cgit v1.2.3