From d80af5f8b58d2bfe23d57e133a8d35eaed59fa13 Mon Sep 17 00:00:00 2001
From: Adin Scannell
Date: Thu, 4 Nov 2021 18:50:20 -0700
Subject: Remove id from sleep.Sleeper API.

In a subsequent change, the Sleeper API will be plumbed through and used
for arbitrary task wakeups. This requires a non-static association of
Wakers and Sleepers, which means that a fixed ID no longer works.

This is a relatively simple change: it removes the ID from the Waker
association and keys on the Waker pointer itself instead.

This change also makes minor improvements to the tests: goroutine starts
are removed from the benchmark hot paths (with Wakers used for the
required synchronization) so that the benchmarks are more representative,
assertion checks are added to AddWaker, and the relevant fields are
cleared during Done (so that those assertions pass).

PiperOrigin-RevId: 407719630
---
 pkg/sleep/sleep_test.go               | 182 ++++++++++++----
 pkg/sleep/sleep_unsafe.go             | 108 ++++++-------
 pkg/syncevent/waiter_test.go          |  22 +--
 pkg/tcpip/link/qdisc/fifo/endpoint.go |  11 +-
 pkg/tcpip/transport/tcp/accept.go     |  13 +-
 pkg/tcpip/transport/tcp/connect.go    | 286 +++++++++++++++-------------------
 pkg/tcpip/transport/tcp/dispatcher.go |   7 +-
 7 files changed, 297 insertions(+), 332 deletions(-)

diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
index 1dd11707d..b27feb99b 100644
--- a/pkg/sleep/sleep_test.go
+++ b/pkg/sleep/sleep_test.go
@@ -67,7 +67,7 @@ func AssertedWakerAfterTwoAsserts(t *testing.T) {
 func NotAssertedWakerWithSleeper(t *testing.T) {
 	var w Waker
 	var s Sleeper
-	s.AddWaker(&w, 0)
+	s.AddWaker(&w)
 	if w.IsAsserted() {
 		t.Fatalf("Non-asserted waker is reported as asserted")
 	}
@@ -83,7 +83,7 @@ func NotAssertedWakerWithSleeper(t *testing.T) {
 func NotAssertedWakerAfterWake(t *testing.T) {
 	var w Waker
 	var s Sleeper
-	s.AddWaker(&w, 0)
+	s.AddWaker(&w)
 	w.Assert()
 	s.Fetch(true)
 	if w.IsAsserted() {
@@ -101,10 +101,10 @@ func AssertedWakerBeforeAdd(t *testing.T) {
 	var w Waker
 	var s Sleeper
 	w.Assert()
-	s.AddWaker(&w, 0)
+	s.AddWaker(&w)
 
-	if _, ok := s.Fetch(false); !ok {
-		t.Fatalf("Fetch failed even though asserted waker was added")
+	if s.Fetch(false) != &w {
+		t.Fatalf("Fetch did not match waker")
 	}
 }
 
@@ -128,7 +128,7 @@ func ClearedWaker(t *testing.T) {
 func ClearedWakerWithSleeper(t *testing.T) {
 	var w Waker
 	var s Sleeper
-	s.AddWaker(&w, 0)
+	s.AddWaker(&w)
 	w.Clear()
 	if w.IsAsserted() {
 		t.Fatalf("Cleared waker is reported as asserted")
@@ -145,7 +145,7 @@ func ClearedWakerWithSleeper(t *testing.T) {
 func ClearedWakerAssertedWithSleeper(t *testing.T) {
 	var w Waker
 	var s Sleeper
-	s.AddWaker(&w, 0)
+	s.AddWaker(&w)
 	w.Assert()
 	w.Clear()
 	if w.IsAsserted() {
@@ -163,18 +163,15 @@ func TestBlock(t *testing.T) {
 	var w Waker
 	var s Sleeper
 
-	s.AddWaker(&w, 0)
+	s.AddWaker(&w)
 
 	// Assert waker after one second.
 	before := time.Now()
-	go func() {
-		time.Sleep(1 * time.Second)
-		w.Assert()
-	}()
+	time.AfterFunc(time.Second, w.Assert)
 
 	// Fetch the result and make sure it took at least 500ms.
-	if _, ok := s.Fetch(true); !ok {
-		t.Fatalf("Fetch failed unexpectedly")
+	if s.Fetch(true) != &w {
+		t.Fatalf("Fetch did not match waker")
 	}
 	if d := time.Now().Sub(before); d < 500*time.Millisecond {
 		t.Fatalf("Duration was too short: %v", d)
@@ -182,8 +179,8 @@ func TestBlock(t *testing.T) {
 
 	// Check that already-asserted waker completes inline.
w.Assert() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") + if s.Fetch(true) != &w { + t.Fatalf("Fetch did not match waker") } // Check that fetch sleeps if waker had been asserted but was reset @@ -191,12 +188,10 @@ func TestBlock(t *testing.T) { w.Assert() w.Clear() before = time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") + time.AfterFunc(time.Second, w.Assert) + + if s.Fetch(true) != &w { + t.Fatalf("Fetch did not match waker") } if d := time.Now().Sub(before); d < 500*time.Millisecond { t.Fatalf("Duration was too short: %v", d) @@ -209,30 +204,30 @@ func TestNonBlock(t *testing.T) { var s Sleeper // Don't block when there's no waker. - if _, ok := s.Fetch(false); ok { + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when there is no waker") } // Don't block when waker isn't asserted. - s.AddWaker(&w, 0) - if _, ok := s.Fetch(false); ok { + s.AddWaker(&w) + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when waker was not asserted") } // Don't block when waker was asserted, but isn't anymore. w.Assert() w.Clear() - if _, ok := s.Fetch(false); ok { + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when waker was not asserted anymore") } // Don't block when waker was consumed by previous Fetch(). w.Assert() - if _, ok := s.Fetch(false); !ok { + if s.Fetch(false) != &w { t.Fatalf("Fetch failed even though waker was asserted") } - if _, ok := s.Fetch(false); ok { + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when waker had been consumed") } } @@ -244,29 +239,30 @@ func TestMultiple(t *testing.T) { w1 := Waker{} w2 := Waker{} - s.AddWaker(&w1, 0) - s.AddWaker(&w2, 1) + s.AddWaker(&w1) + s.AddWaker(&w2) w1.Assert() w2.Assert() - v, ok := s.Fetch(false) - if !ok { + v := s.Fetch(false) + if v == nil { t.Fatalf("Fetch failed when there are asserted wakers") } - - if v != 0 && v != 1 { - t.Fatalf("Unexpected waker id: %v", v) + if v != &w1 && v != &w2 { + t.Fatalf("Unexpected waker: %v", v) } - want := 1 - v - v, ok = s.Fetch(false) - if !ok { + want := &w1 + if v == want { + want = &w2 // Other waiter. + } + v = s.Fetch(false) + if v == nil { t.Fatalf("Fetch failed when there is an asserted waker") } - if v != want { - t.Fatalf("Unexpected waker id, got %v, want %v", v, want) + t.Fatalf("Unexpected waker, got %v, want %v", v, want) } } @@ -281,7 +277,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } s.Done() } @@ -293,7 +289,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } w[i].Assert() s.Done() @@ -307,7 +303,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } w[i].Assert() w[i].Clear() @@ -322,7 +318,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } // Pick the number of asserted elements, then assert @@ -342,14 +338,14 @@ func TestRace(t *testing.T) { const wakers = 100 const wakeRequests = 10000 - counts := make([]int, wakers) - w := make([]Waker, wakers) + counts := make(map[*Waker]int, wakers) s := Sleeper{} // Associate each waker and start goroutines that will assert them. 
- for i := range w { - s.AddWaker(&w[i], i) - go func(w *Waker) { + for i := 0; i < wakers; i++ { + var w Waker + s.AddWaker(&w) + go func() { n := 0 for n < wakeRequests { if !w.IsAsserted() { @@ -359,19 +355,22 @@ func TestRace(t *testing.T) { runtime.Gosched() } } - }(&w[i]) + }() } // Wait for all wake up notifications from all wakers. for i := 0; i < wakers*wakeRequests; i++ { - v, _ := s.Fetch(true) + v := s.Fetch(true) counts[v]++ } // Check that we got the right number for each. - for i, v := range counts { - if v != wakeRequests { - t.Errorf("Waker %v only got %v wakes", i, v) + if got := len(counts); got != wakers { + t.Errorf("Got %d wakers, wanted %d", got, wakers) + } + for _, count := range counts { + if count != wakeRequests { + t.Errorf("Waker only got %d wakes, wanted %d", count, wakeRequests) } } } @@ -384,7 +383,7 @@ func TestRaceInOrder(t *testing.T) { // Associate each waker and start goroutines that will assert them. for i := range w { - s.AddWaker(&w[i], i) + s.AddWaker(&w[i]) } go func() { for i := range w { @@ -393,10 +392,10 @@ func TestRaceInOrder(t *testing.T) { }() // Wait for all wake up notifications from all wakers. - for want := range w { - got, _ := s.Fetch(true) - if got != want { - t.Fatalf("got %d want %d", got, want) + for i := range w { + got := s.Fetch(true) + if want := &w[i]; got != want { + t.Fatalf("got %v want %v", got, want) } } } @@ -408,7 +407,7 @@ func BenchmarkSleeperMultiSelect(b *testing.B) { s := Sleeper{} w := make([]Waker, count) for i := range w { - s.AddWaker(&w[i], i) + s.AddWaker(&w[i]) } b.ResetTimer() @@ -444,7 +443,7 @@ func BenchmarkGoMultiSelect(b *testing.B) { func BenchmarkSleeperSingleSelect(b *testing.B) { s := Sleeper{} w := Waker{} - s.AddWaker(&w, 0) + s.AddWaker(&w) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -494,16 +493,24 @@ func BenchmarkGoAssertNonWaiting(b *testing.B) { // a new goroutine doesn't run immediately (i.e., the creator of a new goroutine // is allowed to go to sleep before the new goroutine has a chance to run). func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - for i := 0; i < b.N; i++ { - go func() { + var ( + s Sleeper + w Waker + ns Sleeper + nw Waker + ) + ns.AddWaker(&nw) + s.AddWaker(&w) + go func() { + for i := 0; i < b.N; i++ { + ns.Fetch(true) w.Assert() - }() + } + }() + for i := 0; i < b.N; i++ { + nw.Assert() s.Fetch(true) } - } // BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one @@ -511,11 +518,13 @@ func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { // goroutine doesn't run immediately (i.e., the creator of a new goroutine is // allowed to go to sleep before the new goroutine has a chance to run). func BenchmarkGoWaitOnSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - for i := 0; i < b.N; i++ { - go func() { + ch := make(chan struct{}) + go func() { + for i := 0; i < b.N; i++ { ch <- struct{}{} - }() + } + }() + for i := 0; i < b.N; i++ { <-ch } } @@ -526,17 +535,26 @@ func BenchmarkGoWaitOnSingleSelect(b *testing.B) { // allowed to go to sleep before the new goroutine has a chance to run). 
func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) { const count = 4 - s := Sleeper{} + var ( + s Sleeper + ns Sleeper + nw Waker + ) + ns.AddWaker(&nw) w := make([]Waker, count) for i := range w { - s.AddWaker(&w[i], i) + s.AddWaker(&w[i]) } b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { + go func() { + for i := 0; i < b.N; i++ { + ns.Fetch(true) w[count-1].Assert() - }() + } + }() + for i := 0; i < b.N; i++ { + nw.Assert() s.Fetch(true) } } @@ -549,14 +567,16 @@ func BenchmarkGoWaitOnMultiSelect(b *testing.B) { const count = 4 ch := make([]chan struct{}, count) for i := range ch { - ch[i] = make(chan struct{}, 1) + ch[i] = make(chan struct{}) } b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { + go func() { + for i := 0; i < b.N; i++ { ch[count-1] <- struct{}{} - }() + } + }() + for i := 0; i < b.N; i++ { select { case <-ch[0]: case <-ch[1]: diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index c44206b1e..86c7cc983 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -37,15 +37,15 @@ // // // One time set-up. // s := sleep.Sleeper{} -// s.AddWaker(&w1, constant1) -// s.AddWaker(&w2, constant2) +// s.AddWaker(&w1) +// s.AddWaker(&w2) // // // Called repeatedly. // for { -// switch id, _ := s.Fetch(true); id { -// case constant1: +// switch s.Fetch(true) { +// case &w1: // // Do work triggered by w1 being asserted. -// case constant2: +// case &w2: // // Do work triggered by w2 being asserted. // } // } @@ -119,13 +119,18 @@ type Sleeper struct { waitingG uintptr } -// AddWaker associates the given waker to the sleeper. id is the value to be -// returned when the sleeper is woken by the given waker. -func (s *Sleeper) AddWaker(w *Waker, id int) { +// AddWaker associates the given waker to the sleeper. +func (s *Sleeper) AddWaker(w *Waker) { + if w.allWakersNext != nil { + panic("waker has non-nil allWakersNext; owned by another sleeper?") + } + if w.next != nil { + panic("waker has non-nil next; queued in another sleeper?") + } + // Add the waker to the list of all wakers. w.allWakersNext = s.allWakers s.allWakers = w - w.id = id // Try to associate the waker with the sleeper. If it's already // asserted, we simply enqueue it in the "ready" list. @@ -213,28 +218,26 @@ func commitSleep(g uintptr, waitingG unsafe.Pointer) bool { return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(waitingG), preparingG, g) } -// Fetch fetches the next wake-up notification. If a notification is immediately -// available, it is returned right away. Otherwise, the behavior depends on the -// value of 'block': if true, the current goroutine blocks until a notification -// arrives, then returns it; if false, returns 'ok' as false. -// -// When 'ok' is true, the value of 'id' corresponds to the id associated with -// the waker; when 'ok' is false, 'id' is undefined. +// Fetch fetches the next wake-up notification. If a notification is +// immediately available, the asserted waker is returned immediately. +// Otherwise, the behavior depends on the value of 'block': if true, the +// current goroutine blocks until a notification arrives and returns the +// asserted waker; if false, nil will be returned. // // N.B. This method is *not* thread-safe. Only one goroutine at a time is // allowed to call this method. -func (s *Sleeper) Fetch(block bool) (id int, ok bool) { +func (s *Sleeper) Fetch(block bool) *Waker { for { w := s.nextWaker(block) if w == nil { - return -1, false + return nil } // Reassociate the waker with the sleeper. 
If the waker was // still asserted we can return it, otherwise try the next one. old := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(s))) if old == &assertedSleeper { - return w.id, true + return w } } } @@ -243,51 +246,34 @@ func (s *Sleeper) Fetch(block bool) (id int, ok bool) { // removes the association with all wakers so that they can be safely reused // by another sleeper after Done() returns. func (s *Sleeper) Done() { - // Remove all associations that we can, and build a list of the ones - // we could not. An association can be removed right away from waker w - // if w.s has a pointer to the sleeper, that is, the waker is not - // asserted yet. By atomically switching w.s to nil, we guarantee that - // subsequent calls to Assert() on the waker will not result in it being - // queued to this sleeper. - var pending *Waker - w := s.allWakers - for w != nil { - next := w.allWakersNext - for { - t := atomic.LoadPointer(&w.s) - if t != usleeper(s) { - w.allWakersNext = pending - pending = w - break - } - - if atomic.CompareAndSwapPointer(&w.s, t, nil) { - break - } + // Remove all associations that we can, and build a list of the ones we + // could not. An association can be removed right away from waker w if + // w.s has a pointer to the sleeper, that is, the waker is not asserted + // yet. By atomically switching w.s to nil, we guarantee that + // subsequent calls to Assert() on the waker will not result in it + // being queued. + for w := s.allWakers; w != nil; w = s.allWakers { + next := w.allWakersNext // Before zapping. + if atomic.CompareAndSwapPointer(&w.s, usleeper(s), nil) { + w.allWakersNext = nil + w.next = nil + s.allWakers = next // Move ahead. + continue } - w = next - } - // The associations that we could not remove are either asserted, or in - // the process of being asserted, or have been asserted and cleared - // before being pulled from the sleeper lists. We must wait for them all - // to make it to the sleeper lists, so that we know that the wakers - // won't do any more work towards waking this sleeper up. - for pending != nil { - pulled := s.nextWaker(true) - - // Remove the waker we just pulled from the list of associated - // wakers. - prev := &pending - for w := *prev; w != nil; w = *prev { - if pulled == w { - *prev = w.allWakersNext - break + // Dequeue exactly one waiter from the list, it may not be + // this one but we know this one is in the process. We must + // leave it in the asserted state but drop it from our lists. + if w := s.nextWaker(true); w != nil { + prev := &s.allWakers + for *prev != w { + prev = &((*prev).allWakersNext) } - prev = &w.allWakersNext + *prev = (*prev).allWakersNext + w.allWakersNext = nil + w.next = nil } } - s.allWakers = nil } // enqueueAssertedWaker enqueues an asserted waker to the "ready" circular list @@ -349,10 +335,6 @@ type Waker struct { // allWakersNext is used to form a linked list of all wakers associated // to a given sleeper. allWakersNext *Waker - - // id is the value to be returned to sleepers when they wake up due to - // this waker being asserted. - id int } // Assert moves the waker to an asserted state, if it isn't asserted yet. 
When diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go index 3c8cbcdd8..cfa0972c0 100644 --- a/pkg/syncevent/waiter_test.go +++ b/pkg/syncevent/waiter_test.go @@ -105,7 +105,7 @@ func BenchmarkWaiterNotifyRedundant(b *testing.B) { func BenchmarkSleeperNotifyRedundant(b *testing.B) { var s sleep.Sleeper var w sleep.Waker - s.AddWaker(&w, 0) + s.AddWaker(&w) w.Assert() b.ResetTimer() @@ -146,7 +146,7 @@ func BenchmarkWaiterNotifyWaitAck(b *testing.B) { func BenchmarkSleeperNotifyWaitAck(b *testing.B) { var s sleep.Sleeper var w sleep.Waker - s.AddWaker(&w, 0) + s.AddWaker(&w) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -197,7 +197,7 @@ func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { w := wakerPool.Get().(*sleep.Waker) - s.AddWaker(w, 0) + s.AddWaker(w) w.Assert() s.Fetch(true) s.Done() @@ -237,7 +237,7 @@ func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { s := sleeperPool.Get().(*sleep.Sleeper) - s.AddWaker(&w, 0) + s.AddWaker(&w) w.Assert() s.Fetch(true) s.Done() @@ -266,14 +266,14 @@ func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) { var s sleep.Sleeper var ws [3]sleep.Waker for i := range ws { - s.AddWaker(&ws[i], i) + s.AddWaker(&ws[i]) } b.ResetTimer() for i := 0; i < b.N; i++ { ws[0].Assert() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, wanted 0", id) + if v := s.Fetch(true); v != &ws[0] { + b.Fatalf("Fetch: got %v, wanted %v", v, &ws[0]) } } } @@ -325,7 +325,7 @@ func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) { func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) { var s sleep.Sleeper var w sleep.Waker - s.AddWaker(&w, 0) + s.AddWaker(&w) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -374,7 +374,7 @@ func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) { var s sleep.Sleeper var ws [3]sleep.Waker for i := range ws { - s.AddWaker(&ws[i], i) + s.AddWaker(&ws[i]) } b.ResetTimer() @@ -382,8 +382,8 @@ func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) { go func() { ws[0].Assert() }() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, expected 0", id) + if v := s.Fetch(true); v != &ws[0] { + b.Fatalf("Fetch: got %v, expected %v", v, &ws[0]) } } } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b41e3e2fa..c15cbf81b 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -73,20 +73,19 @@ func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint { } func (q *queueDispatcher) dispatchLoop() { - const newPacketWakerID = 1 - const closeWakerID = 2 s := sleep.Sleeper{} - s.AddWaker(&q.newPacketWaker, newPacketWakerID) - s.AddWaker(&q.closeWaker, closeWakerID) + s.AddWaker(&q.newPacketWaker) + s.AddWaker(&q.closeWaker) defer s.Done() const batchSize = 32 var batch stack.PacketBufferList for { - id, ok := s.Fetch(true) - if ok && id == closeWakerID { + w := s.Fetch(true) + if w == &q.closeWaker { return } + // Must otherwise be the newPacketWaker. 
 		for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() {
 			batch.PushBack(pkt)
 			if batch.Len() < batchSize && !q.q.empty() {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index caf14b0dc..d0f68b72c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -762,14 +762,15 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
 	}()
 
 	var s sleep.Sleeper
-	s.AddWaker(&e.notificationWaker, wakerForNotification)
-	s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+	s.AddWaker(&e.notificationWaker)
+	s.AddWaker(&e.newSegmentWaker)
+	defer s.Done()
 	for {
 		e.mu.Unlock()
-		index, _ := s.Fetch(true)
+		w := s.Fetch(true)
 		e.mu.Lock()
-		switch index {
-		case wakerForNotification:
+		switch w {
+		case &e.notificationWaker:
 			n := e.fetchNotifications()
 			if n&notifyClose != 0 {
 				return
@@ -788,7 +789,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
 				e.mu.Lock()
 			}
 
-		case wakerForNewSegment:
+		case &e.newSegmentWaker:
 			// Process at most maxSegmentsPerWake segments.
 			mayRequeue := true
 			for i := 0; i < maxSegmentsPerWake; i++ {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 80cd07218..12df7a7b4 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -51,13 +51,6 @@ const (
 	handshakeCompleted
 )
 
-// The following are used to set up sleepers.
-const (
-	wakerForNotification = iota
-	wakerForNewSegment
-	wakerForResend
-)
-
 const (
 	// Maximum space available for options.
 	maxOptionSize = 40
@@ -530,9 +523,9 @@ func (h *handshake) complete() tcpip.Error {
 	// Set up the wakers.
 	var s sleep.Sleeper
 	resendWaker := sleep.Waker{}
-	s.AddWaker(&resendWaker, wakerForResend)
-	s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
-	s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+	s.AddWaker(&resendWaker)
+	s.AddWaker(&h.ep.notificationWaker)
+	s.AddWaker(&h.ep.newSegmentWaker)
 	defer s.Done()
 
 	// Initialize the resend timer.
@@ -545,11 +538,10 @@ func (h *handshake) complete() tcpip.Error {
 		// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
 		// throughout handshake processing).
 		h.ep.mu.Unlock()
-		index, _ := s.Fetch(true /* block */)
+		w := s.Fetch(true /* block */)
 		h.ep.mu.Lock()
-		switch index {
-
-		case wakerForResend:
+		switch w {
+		case &resendWaker:
 			if err := timer.reset(); err != nil {
 				return err
 			}
@@ -577,7 +569,7 @@ func (h *handshake) complete() tcpip.Error {
 				h.sampleRTTWithTSOnly = true
 			}
 
-		case wakerForNotification:
+		case &h.ep.notificationWaker:
 			n := h.ep.fetchNotifications()
 			if (n&notifyClose)|(n&notifyAbort) != 0 {
 				return &tcpip.ErrAborted{}
@@ -611,7 +603,7 @@ func (h *handshake) complete() tcpip.Error {
 				// cleared because of a socket layer call.
 				return &tcpip.ErrConnectionAborted{}
 			}
-		case wakerForNewSegment:
+		case &h.ep.newSegmentWaker:
 			if err := h.processSegments(); err != nil {
 				return err
 			}
@@ -1346,6 +1338,103 @@ func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) {
 	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
 }
 
+// handleWakeup handles a wakeup event while connected.
+//
+// +checklocks:e.mu
+func (e *endpoint) handleWakeup(w, closeWaker *sleep.Waker, closeTimer *tcpip.Timer) tcpip.Error {
+	switch w {
+	case &e.sndQueueInfo.sndWaker:
+		e.sendData(nil /* next */)
+	case &e.newSegmentWaker:
+		return e.handleSegmentsLocked(false /* fastPath */)
+	case &e.snd.resendWaker:
+		if !e.snd.retransmitTimerExpired() {
+			e.stack.Stats().TCP.EstablishedTimedout.Increment()
+			return &tcpip.ErrTimeout{}
+		}
+	case closeWaker:
+		// This means the socket is being closed because the
+		// TCP-FIN-WAIT2 timeout was hit. Just mark the socket as
+		// closed.
+		e.transitionToStateCloseLocked()
+		e.workerCleanup = true
+	case &e.snd.probeWaker:
+		return e.snd.probeTimerExpired()
+	case &e.keepalive.waker:
+		return e.keepaliveTimerExpired()
+	case &e.notificationWaker:
+		n := e.fetchNotifications()
+		if n&notifyNonZeroReceiveWindow != 0 {
+			e.rcv.nonZeroWindow()
+		}
+
+		if n&notifyMTUChanged != 0 {
+			e.sndQueueInfo.sndQueueMu.Lock()
+			count := e.sndQueueInfo.PacketTooBigCount
+			e.sndQueueInfo.PacketTooBigCount = 0
+			mtu := e.sndQueueInfo.SndMTU
+			e.sndQueueInfo.sndQueueMu.Unlock()
+
+			e.snd.updateMaxPayloadSize(mtu, count)
+		}
+
+		if n&notifyReset != 0 || n&notifyAbort != 0 {
+			return &tcpip.ErrConnectionAborted{}
+		}
+
+		if n&notifyResetByPeer != 0 {
+			return &tcpip.ErrConnectionReset{}
+		}
+
+		if n&notifyClose != 0 && e.closed {
+			switch e.EndpointState() {
+			case StateEstablished:
+				// Perform full shutdown if the endpoint is
+				// still established. This can occur when
+				// notifyClose was asserted just before
+				// becoming established.
+				e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+			case StateFinWait2:
+				// The socket has been closed and we are in
+				// FIN_WAIT2 so start the FIN_WAIT2 timer.
+				if *closeTimer == nil {
+					*closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+				}
+			}
+		}
+
+		if n&notifyKeepaliveChanged != 0 {
+			// The timer could fire in background when the endpoint
+			// is drained. That's OK. See above.
+			e.resetKeepaliveTimer(true)
+		}
+
+		if n&notifyDrain != 0 {
+			for !e.segmentQueue.empty() {
+				if err := e.handleSegmentsLocked(false /* fastPath */); err != nil {
+					return err
+				}
+			}
+			if !e.EndpointState().closed() {
+				// Only block the worker if the endpoint
+				// is not in closed state or error state.
+				close(e.drainDone)
+				e.mu.Unlock()
+				<-e.undrain
+				e.mu.Lock()
+			}
+		}
+
+		// N.B. notifyTickleWorker may be set, but there is no action
+		// to take in this case.
+	case &e.snd.reorderWaker:
+		return e.snd.rc.reorderTimerExpired()
+	default:
+		panic("unknown waker") // Shouldn't happen.
+	}
+	return nil
+}
+
 // protocolMainLoop is the main loop of the TCP protocol. It runs in its own
 // goroutine and is responsible for sending segments and handling received
 // segments.
@@ -1403,139 +1492,16 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
 		e.mu.Lock()
 	}
 
-	// Set up the functions that will be called when the main protocol loop
-	// wakes up.
-	funcs := []struct {
-		w *sleep.Waker
-		f func() tcpip.Error
-	}{
-		{
-			w: &e.sndQueueInfo.sndWaker,
-			f: func() tcpip.Error {
-				e.sendData(nil /* next */)
-				return nil
-			},
-		},
-		{
-			w: &closeWaker,
-			f: func() tcpip.Error {
-				// This means the socket is being closed due
-				// to the TCP-FIN-WAIT2 timeout was hit. Just
-				// mark the socket as closed.
-				e.transitionToStateCloseLocked()
-				e.workerCleanup = true
-				return nil
-			},
-		},
-		{
-			w: &e.snd.resendWaker,
-			f: func() tcpip.Error {
-				if !e.snd.retransmitTimerExpired() {
-					e.stack.Stats().TCP.EstablishedTimedout.Increment()
-					return &tcpip.ErrTimeout{}
-				}
-				return nil
-			},
-		},
-		{
-			w: &e.snd.probeWaker,
-			f: e.snd.probeTimerExpired,
-		},
-		{
-			w: &e.newSegmentWaker,
-			f: func() tcpip.Error {
-				return e.handleSegmentsLocked(false /* fastPath */)
-			},
-		},
-		{
-			w: &e.keepalive.waker,
-			f: e.keepaliveTimerExpired,
-		},
-		{
-			w: &e.notificationWaker,
-			f: func() tcpip.Error {
-				n := e.fetchNotifications()
-				if n&notifyNonZeroReceiveWindow != 0 {
-					e.rcv.nonZeroWindow()
-				}
-
-				if n&notifyMTUChanged != 0 {
-					e.sndQueueInfo.sndQueueMu.Lock()
-					count := e.sndQueueInfo.PacketTooBigCount
-					e.sndQueueInfo.PacketTooBigCount = 0
-					mtu := e.sndQueueInfo.SndMTU
-					e.sndQueueInfo.sndQueueMu.Unlock()
-
-					e.snd.updateMaxPayloadSize(mtu, count)
-				}
-
-				if n&notifyReset != 0 || n&notifyAbort != 0 {
-					return &tcpip.ErrConnectionAborted{}
-				}
-
-				if n&notifyResetByPeer != 0 {
-					return &tcpip.ErrConnectionReset{}
-				}
-
-				if n&notifyClose != 0 && e.closed {
-					switch e.EndpointState() {
-					case StateEstablished:
-						// Perform full shutdown if the endpoint is still
-						// established. This can occur when notifyClose
-						// was asserted just before becoming established.
-						e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
-					case StateFinWait2:
-						// The socket has been closed and we are in FIN_WAIT2
-						// so start the FIN_WAIT2 timer.
-						if closeTimer == nil {
-							closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
-						}
-					}
-				}
-
-				if n&notifyKeepaliveChanged != 0 {
-					// The timer could fire in background
-					// when the endpoint is drained. That's
-					// OK. See above.
-					e.resetKeepaliveTimer(true)
-				}
-
-				if n&notifyDrain != 0 {
-					for !e.segmentQueue.empty() {
-						if err := e.handleSegmentsLocked(false /* fastPath */); err != nil {
-							return err
-						}
-					}
-					if !e.EndpointState().closed() {
-						// Only block the worker if the endpoint
-						// is not in closed state or error state.
-						close(e.drainDone)
-						e.mu.Unlock() // +checklocksforce
-						<-e.undrain
-						e.mu.Lock()
-					}
-				}
-
-				if n&notifyTickleWorker != 0 {
-					// Just a tickle notification. No need to do
-					// anything.
-					return nil
-				}
-
-				return nil
-			},
-		},
-		{
-			w: &e.snd.reorderWaker,
-			f: e.snd.rc.reorderTimerExpired,
-		},
-	}
-
-	// Initialize the sleeper based on the wakers in funcs.
+	// Add all wakers.
 	var s sleep.Sleeper
-	for i := range funcs {
-		s.AddWaker(funcs[i].w, i)
-	}
+	s.AddWaker(&e.sndQueueInfo.sndWaker)
+	s.AddWaker(&e.newSegmentWaker)
+	s.AddWaker(&e.snd.resendWaker)
+	s.AddWaker(&e.snd.probeWaker)
+	s.AddWaker(&closeWaker)
+	s.AddWaker(&e.keepalive.waker)
+	s.AddWaker(&e.notificationWaker)
+	s.AddWaker(&e.snd.reorderWaker)
 
 	// Notify the caller that the waker initialization is complete and the
 	// endpoint is ready.
@@ -1581,7 +1547,7 @@ loop:
 		}
 
 		e.mu.Unlock()
-		v, _ := s.Fetch(true /* block */)
+		w := s.Fetch(true /* block */)
 		e.mu.Lock()
 
 		// We need to double check here because the notification may be
@@ -1601,7 +1567,7 @@ loop:
 		case StateClose:
 			break loop
 		default:
-			if err := funcs[v].f(); err != nil {
+			if err := e.handleWakeup(w, &closeWaker, &closeTimer); err != nil {
 				cleanupOnError(err)
 				e.protocolMainLoopDone(closeTimer)
 				return
@@ -1714,26 +1680,22 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
 		timeWaitDuration = time.Duration(tcpTW)
 	}
 
-	const newSegment = 1
-	const notification = 2
-	const timeWaitDone = 3
-
 	var s sleep.Sleeper
 	defer s.Done()
-	s.AddWaker(&e.newSegmentWaker, newSegment)
-	s.AddWaker(&e.notificationWaker, notification)
+	s.AddWaker(&e.newSegmentWaker)
+	s.AddWaker(&e.notificationWaker)
 
 	var timeWaitWaker sleep.Waker
-	s.AddWaker(&timeWaitWaker, timeWaitDone)
+	s.AddWaker(&timeWaitWaker)
 	timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
 	defer timeWaitTimer.Stop()
 
 	for {
 		e.mu.Unlock()
-		v, _ := s.Fetch(true /* block */)
+		w := s.Fetch(true /* block */)
 		e.mu.Lock()
-		switch v {
-		case newSegment:
+		switch w {
+		case &e.newSegmentWaker:
 			extendTimeWait, reuseTW := e.handleTimeWaitSegments()
 			if reuseTW != nil {
 				return reuseTW
@@ -1741,7 +1703,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
 			if extendTimeWait {
 				timeWaitTimer.Reset(timeWaitDuration)
 			}
-		case notification:
+		case &e.notificationWaker:
 			n := e.fetchNotifications()
 			if n&notifyAbort != 0 {
 				return nil
@@ -1759,7 +1721,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
 				e.mu.Lock()
 				return nil
 			}
-		case timeWaitDone:
+		case &timeWaitWaker:
 			return nil
 		}
 	}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 7d110516b..2e93d2664 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -94,9 +94,10 @@ func (p *processor) start(wg *sync.WaitGroup) {
 	defer p.sleeper.Done()
 
 	for {
-		if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+		if w := p.sleeper.Fetch(true); w == &p.closeWaker {
 			break
 		}
+		// If not the closeWaker, it must be &p.newEndpointWaker.
 		for {
 			ep := p.epQ.dequeue()
 			if ep == nil {
@@ -154,8 +155,8 @@ func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
 	d.seed = rng.Uint32()
 	for i := range d.processors {
 		p := &d.processors[i]
-		p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
-		p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+		p.sleeper.AddWaker(&p.newEndpointWaker)
+		p.sleeper.AddWaker(&p.closeWaker)
 		d.wg.Add(1)
 		// NB: sleeper-waker registration must happen synchronously to avoid races
 		// with `close`. It's possible to pull all this logic into `start`, but
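
At a call site, the migration looks as follows. This is a minimal,
self-contained sketch rather than code from the patch: it exercises the
pkg/sleep API exactly as changed above, with the pre-change calls shown
in comments for contrast.

package main

import "gvisor.dev/gvisor/pkg/sleep"

func main() {
	var (
		s      sleep.Sleeper
		w1, w2 sleep.Waker
	)

	// Before this change, AddWaker took a caller-chosen ID and Fetch
	// returned (id int, ok bool):
	//
	//	s.AddWaker(&w1, 0)
	//	s.AddWaker(&w2, 1)
	//	switch id, _ := s.Fetch(true); id {
	//	case 0: // w1 fired.
	//	case 1: // w2 fired.
	//	}

	// After this change, AddWaker takes only the waker, and Fetch
	// returns the asserted *Waker (nil for a non-blocking Fetch that
	// finds nothing).
	s.AddWaker(&w1)
	s.AddWaker(&w2)
	defer s.Done()

	w2.Assert()
	switch s.Fetch(true /* block */) {
	case &w1:
		// Work triggered by w1.
	case &w2:
		// Work triggered by w2.
	}
}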
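The assertion checks added to AddWaker and the field-clearing added to
Done define a reuse contract: a Waker may be registered with at most one
Sleeper at a time, and may move to another Sleeper only after Done
disassociates it. A sketch of that contract, again not from the patch
(the panic behavior referenced is the one introduced in sleep_unsafe.go
above):

package main

import "gvisor.dev/gvisor/pkg/sleep"

func main() {
	var w1, w2 sleep.Waker

	var s1 sleep.Sleeper
	s1.AddWaker(&w1)
	s1.AddWaker(&w2) // w2.allWakersNext now points at w1.

	// At this point another sleeper must not take w2; with this change,
	// s2.AddWaker(&w2) would panic ("owned by another sleeper?").
	s1.Done() // Disassociates both wakers and clears their list fields.

	var s2 sleep.Sleeper
	s2.AddWaker(&w2) // Legal only because s1.Done() ran first.
	w2.Assert()
	if s2.Fetch(false /* block */) != &w2 {
		panic("expected w2 to have been asserted")
	}
	s2.Done()
}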
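The benchmark changes in sleep_test.go move goroutine creation out of the
timed loop: a single peer goroutine is driven through a second
Sleeper/Waker pair, so each iteration measures only Assert/Fetch
round-trips. The shape of that pattern, reduced to a standalone sketch
(the iteration count is arbitrary):

package main

import "gvisor.dev/gvisor/pkg/sleep"

func main() {
	const iterations = 1000

	var (
		s  sleep.Sleeper // Fetched by the main goroutine.
		w  sleep.Waker   // Asserted by the peer.
		ns sleep.Sleeper // Fetched by the peer (the "notify" pair).
		nw sleep.Waker   // Asserted by the main goroutine.
	)
	s.AddWaker(&w)
	ns.AddWaker(&nw)

	go func() {
		for i := 0; i < iterations; i++ {
			ns.Fetch(true /* block */) // Wait until asked to fire...
			w.Assert()                 // ...then wake the main goroutine.
		}
	}()
	for i := 0; i < iterations; i++ {
		nw.Assert()               // Ask the peer to fire.
		s.Fetch(true /* block */) // Block until it does.
	}
	// Both loops complete in lockstep; the program then exits.
}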