diff options
-rw-r--r-- | pkg/sleep/sleep_test.go | 182 | ||||
-rw-r--r-- | pkg/sleep/sleep_unsafe.go | 108 | ||||
-rw-r--r-- | pkg/syncevent/waiter_test.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/link/qdisc/fifo/endpoint.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 286 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/dispatcher.go | 7 |
7 files changed, 297 insertions, 332 deletions
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index 1dd11707d..b27feb99b 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -67,7 +67,7 @@ func AssertedWakerAfterTwoAsserts(t *testing.T) { func NotAssertedWakerWithSleeper(t *testing.T) { var w Waker var s Sleeper - s.AddWaker(&w, 0) + s.AddWaker(&w) if w.IsAsserted() { t.Fatalf("Non-asserted waker is reported as asserted") } @@ -83,7 +83,7 @@ func NotAssertedWakerWithSleeper(t *testing.T) { func NotAssertedWakerAfterWake(t *testing.T) { var w Waker var s Sleeper - s.AddWaker(&w, 0) + s.AddWaker(&w) w.Assert() s.Fetch(true) if w.IsAsserted() { @@ -101,10 +101,10 @@ func AssertedWakerBeforeAdd(t *testing.T) { var w Waker var s Sleeper w.Assert() - s.AddWaker(&w, 0) + s.AddWaker(&w) - if _, ok := s.Fetch(false); !ok { - t.Fatalf("Fetch failed even though asserted waker was added") + if s.Fetch(false) != &w { + t.Fatalf("Fetch did not match waker") } } @@ -128,7 +128,7 @@ func ClearedWaker(t *testing.T) { func ClearedWakerWithSleeper(t *testing.T) { var w Waker var s Sleeper - s.AddWaker(&w, 0) + s.AddWaker(&w) w.Clear() if w.IsAsserted() { t.Fatalf("Cleared waker is reported as asserted") @@ -145,7 +145,7 @@ func ClearedWakerWithSleeper(t *testing.T) { func ClearedWakerAssertedWithSleeper(t *testing.T) { var w Waker var s Sleeper - s.AddWaker(&w, 0) + s.AddWaker(&w) w.Assert() w.Clear() if w.IsAsserted() { @@ -163,18 +163,15 @@ func TestBlock(t *testing.T) { var w Waker var s Sleeper - s.AddWaker(&w, 0) + s.AddWaker(&w) // Assert waker after one second. before := time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() + time.AfterFunc(time.Second, w.Assert) // Fetch the result and make sure it took at least 500ms. - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") + if s.Fetch(true) != &w { + t.Fatalf("Fetch did not match waker") } if d := time.Now().Sub(before); d < 500*time.Millisecond { t.Fatalf("Duration was too short: %v", d) @@ -182,8 +179,8 @@ func TestBlock(t *testing.T) { // Check that already-asserted waker completes inline. w.Assert() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") + if s.Fetch(true) != &w { + t.Fatalf("Fetch did not match waker") } // Check that fetch sleeps if waker had been asserted but was reset @@ -191,12 +188,10 @@ func TestBlock(t *testing.T) { w.Assert() w.Clear() before = time.Now() - go func() { - time.Sleep(1 * time.Second) - w.Assert() - }() - if _, ok := s.Fetch(true); !ok { - t.Fatalf("Fetch failed unexpectedly") + time.AfterFunc(time.Second, w.Assert) + + if s.Fetch(true) != &w { + t.Fatalf("Fetch did not match waker") } if d := time.Now().Sub(before); d < 500*time.Millisecond { t.Fatalf("Duration was too short: %v", d) @@ -209,30 +204,30 @@ func TestNonBlock(t *testing.T) { var s Sleeper // Don't block when there's no waker. - if _, ok := s.Fetch(false); ok { + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when there is no waker") } // Don't block when waker isn't asserted. - s.AddWaker(&w, 0) - if _, ok := s.Fetch(false); ok { + s.AddWaker(&w) + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when waker was not asserted") } // Don't block when waker was asserted, but isn't anymore. w.Assert() w.Clear() - if _, ok := s.Fetch(false); ok { + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when waker was not asserted anymore") } // Don't block when waker was consumed by previous Fetch(). w.Assert() - if _, ok := s.Fetch(false); !ok { + if s.Fetch(false) != &w { t.Fatalf("Fetch failed even though waker was asserted") } - if _, ok := s.Fetch(false); ok { + if s.Fetch(false) != nil { t.Fatalf("Fetch succeeded when waker had been consumed") } } @@ -244,29 +239,30 @@ func TestMultiple(t *testing.T) { w1 := Waker{} w2 := Waker{} - s.AddWaker(&w1, 0) - s.AddWaker(&w2, 1) + s.AddWaker(&w1) + s.AddWaker(&w2) w1.Assert() w2.Assert() - v, ok := s.Fetch(false) - if !ok { + v := s.Fetch(false) + if v == nil { t.Fatalf("Fetch failed when there are asserted wakers") } - - if v != 0 && v != 1 { - t.Fatalf("Unexpected waker id: %v", v) + if v != &w1 && v != &w2 { + t.Fatalf("Unexpected waker: %v", v) } - want := 1 - v - v, ok = s.Fetch(false) - if !ok { + want := &w1 + if v == want { + want = &w2 // Other waiter. + } + v = s.Fetch(false) + if v == nil { t.Fatalf("Fetch failed when there is an asserted waker") } - if v != want { - t.Fatalf("Unexpected waker id, got %v, want %v", v, want) + t.Fatalf("Unexpected waker, got %v, want %v", v, want) } } @@ -281,7 +277,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } s.Done() } @@ -293,7 +289,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } w[i].Assert() s.Done() @@ -307,7 +303,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } w[i].Assert() w[i].Clear() @@ -322,7 +318,7 @@ func TestDoneFunction(t *testing.T) { s := Sleeper{} w := make([]Waker, n) for j := 0; j < n; j++ { - s.AddWaker(&w[j], j) + s.AddWaker(&w[j]) } // Pick the number of asserted elements, then assert @@ -342,14 +338,14 @@ func TestRace(t *testing.T) { const wakers = 100 const wakeRequests = 10000 - counts := make([]int, wakers) - w := make([]Waker, wakers) + counts := make(map[*Waker]int, wakers) s := Sleeper{} // Associate each waker and start goroutines that will assert them. - for i := range w { - s.AddWaker(&w[i], i) - go func(w *Waker) { + for i := 0; i < wakers; i++ { + var w Waker + s.AddWaker(&w) + go func() { n := 0 for n < wakeRequests { if !w.IsAsserted() { @@ -359,19 +355,22 @@ func TestRace(t *testing.T) { runtime.Gosched() } } - }(&w[i]) + }() } // Wait for all wake up notifications from all wakers. for i := 0; i < wakers*wakeRequests; i++ { - v, _ := s.Fetch(true) + v := s.Fetch(true) counts[v]++ } // Check that we got the right number for each. - for i, v := range counts { - if v != wakeRequests { - t.Errorf("Waker %v only got %v wakes", i, v) + if got := len(counts); got != wakers { + t.Errorf("Got %d wakers, wanted %d", got, wakers) + } + for _, count := range counts { + if count != wakeRequests { + t.Errorf("Waker only got %d wakes, wanted %d", count, wakeRequests) } } } @@ -384,7 +383,7 @@ func TestRaceInOrder(t *testing.T) { // Associate each waker and start goroutines that will assert them. for i := range w { - s.AddWaker(&w[i], i) + s.AddWaker(&w[i]) } go func() { for i := range w { @@ -393,10 +392,10 @@ func TestRaceInOrder(t *testing.T) { }() // Wait for all wake up notifications from all wakers. - for want := range w { - got, _ := s.Fetch(true) - if got != want { - t.Fatalf("got %d want %d", got, want) + for i := range w { + got := s.Fetch(true) + if want := &w[i]; got != want { + t.Fatalf("got %v want %v", got, want) } } } @@ -408,7 +407,7 @@ func BenchmarkSleeperMultiSelect(b *testing.B) { s := Sleeper{} w := make([]Waker, count) for i := range w { - s.AddWaker(&w[i], i) + s.AddWaker(&w[i]) } b.ResetTimer() @@ -444,7 +443,7 @@ func BenchmarkGoMultiSelect(b *testing.B) { func BenchmarkSleeperSingleSelect(b *testing.B) { s := Sleeper{} w := Waker{} - s.AddWaker(&w, 0) + s.AddWaker(&w) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -494,16 +493,24 @@ func BenchmarkGoAssertNonWaiting(b *testing.B) { // a new goroutine doesn't run immediately (i.e., the creator of a new goroutine // is allowed to go to sleep before the new goroutine has a chance to run). func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { - s := Sleeper{} - w := Waker{} - s.AddWaker(&w, 0) - for i := 0; i < b.N; i++ { - go func() { + var ( + s Sleeper + w Waker + ns Sleeper + nw Waker + ) + ns.AddWaker(&nw) + s.AddWaker(&w) + go func() { + for i := 0; i < b.N; i++ { + ns.Fetch(true) w.Assert() - }() + } + }() + for i := 0; i < b.N; i++ { + nw.Assert() s.Fetch(true) } - } // BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one @@ -511,11 +518,13 @@ func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { // goroutine doesn't run immediately (i.e., the creator of a new goroutine is // allowed to go to sleep before the new goroutine has a chance to run). func BenchmarkGoWaitOnSingleSelect(b *testing.B) { - ch := make(chan struct{}, 1) - for i := 0; i < b.N; i++ { - go func() { + ch := make(chan struct{}) + go func() { + for i := 0; i < b.N; i++ { ch <- struct{}{} - }() + } + }() + for i := 0; i < b.N; i++ { <-ch } } @@ -526,17 +535,26 @@ func BenchmarkGoWaitOnSingleSelect(b *testing.B) { // allowed to go to sleep before the new goroutine has a chance to run). func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) { const count = 4 - s := Sleeper{} + var ( + s Sleeper + ns Sleeper + nw Waker + ) + ns.AddWaker(&nw) w := make([]Waker, count) for i := range w { - s.AddWaker(&w[i], i) + s.AddWaker(&w[i]) } b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { + go func() { + for i := 0; i < b.N; i++ { + ns.Fetch(true) w[count-1].Assert() - }() + } + }() + for i := 0; i < b.N; i++ { + nw.Assert() s.Fetch(true) } } @@ -549,14 +567,16 @@ func BenchmarkGoWaitOnMultiSelect(b *testing.B) { const count = 4 ch := make([]chan struct{}, count) for i := range ch { - ch[i] = make(chan struct{}, 1) + ch[i] = make(chan struct{}) } b.ResetTimer() - for i := 0; i < b.N; i++ { - go func() { + go func() { + for i := 0; i < b.N; i++ { ch[count-1] <- struct{}{} - }() + } + }() + for i := 0; i < b.N; i++ { select { case <-ch[0]: case <-ch[1]: diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index c44206b1e..86c7cc983 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -37,15 +37,15 @@ // // // One time set-up. // s := sleep.Sleeper{} -// s.AddWaker(&w1, constant1) -// s.AddWaker(&w2, constant2) +// s.AddWaker(&w1) +// s.AddWaker(&w2) // // // Called repeatedly. // for { -// switch id, _ := s.Fetch(true); id { -// case constant1: +// switch s.Fetch(true) { +// case &w1: // // Do work triggered by w1 being asserted. -// case constant2: +// case &w2: // // Do work triggered by w2 being asserted. // } // } @@ -119,13 +119,18 @@ type Sleeper struct { waitingG uintptr } -// AddWaker associates the given waker to the sleeper. id is the value to be -// returned when the sleeper is woken by the given waker. -func (s *Sleeper) AddWaker(w *Waker, id int) { +// AddWaker associates the given waker to the sleeper. +func (s *Sleeper) AddWaker(w *Waker) { + if w.allWakersNext != nil { + panic("waker has non-nil allWakersNext; owned by another sleeper?") + } + if w.next != nil { + panic("waker has non-nil next; queued in another sleeper?") + } + // Add the waker to the list of all wakers. w.allWakersNext = s.allWakers s.allWakers = w - w.id = id // Try to associate the waker with the sleeper. If it's already // asserted, we simply enqueue it in the "ready" list. @@ -213,28 +218,26 @@ func commitSleep(g uintptr, waitingG unsafe.Pointer) bool { return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(waitingG), preparingG, g) } -// Fetch fetches the next wake-up notification. If a notification is immediately -// available, it is returned right away. Otherwise, the behavior depends on the -// value of 'block': if true, the current goroutine blocks until a notification -// arrives, then returns it; if false, returns 'ok' as false. -// -// When 'ok' is true, the value of 'id' corresponds to the id associated with -// the waker; when 'ok' is false, 'id' is undefined. +// Fetch fetches the next wake-up notification. If a notification is +// immediately available, the asserted waker is returned immediately. +// Otherwise, the behavior depends on the value of 'block': if true, the +// current goroutine blocks until a notification arrives and returns the +// asserted waker; if false, nil will be returned. // // N.B. This method is *not* thread-safe. Only one goroutine at a time is // allowed to call this method. -func (s *Sleeper) Fetch(block bool) (id int, ok bool) { +func (s *Sleeper) Fetch(block bool) *Waker { for { w := s.nextWaker(block) if w == nil { - return -1, false + return nil } // Reassociate the waker with the sleeper. If the waker was // still asserted we can return it, otherwise try the next one. old := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(s))) if old == &assertedSleeper { - return w.id, true + return w } } } @@ -243,51 +246,34 @@ func (s *Sleeper) Fetch(block bool) (id int, ok bool) { // removes the association with all wakers so that they can be safely reused // by another sleeper after Done() returns. func (s *Sleeper) Done() { - // Remove all associations that we can, and build a list of the ones - // we could not. An association can be removed right away from waker w - // if w.s has a pointer to the sleeper, that is, the waker is not - // asserted yet. By atomically switching w.s to nil, we guarantee that - // subsequent calls to Assert() on the waker will not result in it being - // queued to this sleeper. - var pending *Waker - w := s.allWakers - for w != nil { - next := w.allWakersNext - for { - t := atomic.LoadPointer(&w.s) - if t != usleeper(s) { - w.allWakersNext = pending - pending = w - break - } - - if atomic.CompareAndSwapPointer(&w.s, t, nil) { - break - } + // Remove all associations that we can, and build a list of the ones we + // could not. An association can be removed right away from waker w if + // w.s has a pointer to the sleeper, that is, the waker is not asserted + // yet. By atomically switching w.s to nil, we guarantee that + // subsequent calls to Assert() on the waker will not result in it + // being queued. + for w := s.allWakers; w != nil; w = s.allWakers { + next := w.allWakersNext // Before zapping. + if atomic.CompareAndSwapPointer(&w.s, usleeper(s), nil) { + w.allWakersNext = nil + w.next = nil + s.allWakers = next // Move ahead. + continue } - w = next - } - // The associations that we could not remove are either asserted, or in - // the process of being asserted, or have been asserted and cleared - // before being pulled from the sleeper lists. We must wait for them all - // to make it to the sleeper lists, so that we know that the wakers - // won't do any more work towards waking this sleeper up. - for pending != nil { - pulled := s.nextWaker(true) - - // Remove the waker we just pulled from the list of associated - // wakers. - prev := &pending - for w := *prev; w != nil; w = *prev { - if pulled == w { - *prev = w.allWakersNext - break + // Dequeue exactly one waiter from the list, it may not be + // this one but we know this one is in the process. We must + // leave it in the asserted state but drop it from our lists. + if w := s.nextWaker(true); w != nil { + prev := &s.allWakers + for *prev != w { + prev = &((*prev).allWakersNext) } - prev = &w.allWakersNext + *prev = (*prev).allWakersNext + w.allWakersNext = nil + w.next = nil } } - s.allWakers = nil } // enqueueAssertedWaker enqueues an asserted waker to the "ready" circular list @@ -349,10 +335,6 @@ type Waker struct { // allWakersNext is used to form a linked list of all wakers associated // to a given sleeper. allWakersNext *Waker - - // id is the value to be returned to sleepers when they wake up due to - // this waker being asserted. - id int } // Assert moves the waker to an asserted state, if it isn't asserted yet. When diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go index 3c8cbcdd8..cfa0972c0 100644 --- a/pkg/syncevent/waiter_test.go +++ b/pkg/syncevent/waiter_test.go @@ -105,7 +105,7 @@ func BenchmarkWaiterNotifyRedundant(b *testing.B) { func BenchmarkSleeperNotifyRedundant(b *testing.B) { var s sleep.Sleeper var w sleep.Waker - s.AddWaker(&w, 0) + s.AddWaker(&w) w.Assert() b.ResetTimer() @@ -146,7 +146,7 @@ func BenchmarkWaiterNotifyWaitAck(b *testing.B) { func BenchmarkSleeperNotifyWaitAck(b *testing.B) { var s sleep.Sleeper var w sleep.Waker - s.AddWaker(&w, 0) + s.AddWaker(&w) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -197,7 +197,7 @@ func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { w := wakerPool.Get().(*sleep.Waker) - s.AddWaker(w, 0) + s.AddWaker(w) w.Assert() s.Fetch(true) s.Done() @@ -237,7 +237,7 @@ func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { s := sleeperPool.Get().(*sleep.Sleeper) - s.AddWaker(&w, 0) + s.AddWaker(&w) w.Assert() s.Fetch(true) s.Done() @@ -266,14 +266,14 @@ func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) { var s sleep.Sleeper var ws [3]sleep.Waker for i := range ws { - s.AddWaker(&ws[i], i) + s.AddWaker(&ws[i]) } b.ResetTimer() for i := 0; i < b.N; i++ { ws[0].Assert() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, wanted 0", id) + if v := s.Fetch(true); v != &ws[0] { + b.Fatalf("Fetch: got %v, wanted %v", v, &ws[0]) } } } @@ -325,7 +325,7 @@ func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) { func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) { var s sleep.Sleeper var w sleep.Waker - s.AddWaker(&w, 0) + s.AddWaker(&w) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -374,7 +374,7 @@ func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) { var s sleep.Sleeper var ws [3]sleep.Waker for i := range ws { - s.AddWaker(&ws[i], i) + s.AddWaker(&ws[i]) } b.ResetTimer() @@ -382,8 +382,8 @@ func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) { go func() { ws[0].Assert() }() - if id, _ := s.Fetch(true); id != 0 { - b.Fatalf("Fetch: got %d, expected 0", id) + if v := s.Fetch(true); v != &ws[0] { + b.Fatalf("Fetch: got %v, expected %v", v, &ws[0]) } } } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b41e3e2fa..c15cbf81b 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -73,20 +73,19 @@ func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint { } func (q *queueDispatcher) dispatchLoop() { - const newPacketWakerID = 1 - const closeWakerID = 2 s := sleep.Sleeper{} - s.AddWaker(&q.newPacketWaker, newPacketWakerID) - s.AddWaker(&q.closeWaker, closeWakerID) + s.AddWaker(&q.newPacketWaker) + s.AddWaker(&q.closeWaker) defer s.Done() const batchSize = 32 var batch stack.PacketBufferList for { - id, ok := s.Fetch(true) - if ok && id == closeWakerID { + w := s.Fetch(true) + if w == &q.closeWaker { return } + // Must otherwise be the newPacketWaker. for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() { batch.PushBack(pkt) if batch.Len() < batchSize && !q.q.empty() { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index caf14b0dc..d0f68b72c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -762,14 +762,15 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { }() var s sleep.Sleeper - s.AddWaker(&e.notificationWaker, wakerForNotification) - s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) + s.AddWaker(&e.notificationWaker) + s.AddWaker(&e.newSegmentWaker) + defer s.Done() for { e.mu.Unlock() - index, _ := s.Fetch(true) + w := s.Fetch(true) e.mu.Lock() - switch index { - case wakerForNotification: + switch w { + case &e.notificationWaker: n := e.fetchNotifications() if n¬ifyClose != 0 { return @@ -788,7 +789,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() } - case wakerForNewSegment: + case &e.newSegmentWaker: // Process at most maxSegmentsPerWake segments. mayRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 80cd07218..12df7a7b4 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -51,13 +51,6 @@ const ( handshakeCompleted ) -// The following are used to set up sleepers. -const ( - wakerForNotification = iota - wakerForNewSegment - wakerForResend -) - const ( // Maximum space available for options. maxOptionSize = 40 @@ -530,9 +523,9 @@ func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper resendWaker := sleep.Waker{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + s.AddWaker(&resendWaker) + s.AddWaker(&h.ep.notificationWaker) + s.AddWaker(&h.ep.newSegmentWaker) defer s.Done() // Initialize the resend timer. @@ -545,11 +538,10 @@ func (h *handshake) complete() tcpip.Error { // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held // throughout handshake processing). h.ep.mu.Unlock() - index, _ := s.Fetch(true /* block */) + w := s.Fetch(true /* block */) h.ep.mu.Lock() - switch index { - - case wakerForResend: + switch w { + case &resendWaker: if err := timer.reset(); err != nil { return err } @@ -577,7 +569,7 @@ func (h *handshake) complete() tcpip.Error { h.sampleRTTWithTSOnly = true } - case wakerForNotification: + case &h.ep.notificationWaker: n := h.ep.fetchNotifications() if (n¬ifyClose)|(n¬ifyAbort) != 0 { return &tcpip.ErrAborted{} @@ -611,7 +603,7 @@ func (h *handshake) complete() tcpip.Error { // cleared because of a socket layer call. return &tcpip.ErrConnectionAborted{} } - case wakerForNewSegment: + case &h.ep.newSegmentWaker: if err := h.processSegments(); err != nil { return err } @@ -1346,6 +1338,103 @@ func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } +// handleWakeup handles a wakeup event while connected. +// +// +checklocks:e.mu +func (e *endpoint) handleWakeup(w, closeWaker *sleep.Waker, closeTimer *tcpip.Timer) tcpip.Error { + switch w { + case &e.sndQueueInfo.sndWaker: + e.sendData(nil /* next */) + case &e.newSegmentWaker: + return e.handleSegmentsLocked(false /* fastPath */) + case &e.snd.resendWaker: + if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return &tcpip.ErrTimeout{} + } + case closeWaker: + // This means the socket is being closed due to the + // TCP-FIN-WAIT2 timeout was hit. Just mark the socket as + // closed. + e.transitionToStateCloseLocked() + e.workerCleanup = true + case &e.snd.probeWaker: + return e.snd.probeTimerExpired() + case &e.keepalive.waker: + return e.keepaliveTimerExpired() + case &e.notificationWaker: + n := e.fetchNotifications() + if n¬ifyNonZeroReceiveWindow != 0 { + e.rcv.nonZeroWindow() + } + + if n¬ifyMTUChanged != 0 { + e.sndQueueInfo.sndQueueMu.Lock() + count := e.sndQueueInfo.PacketTooBigCount + e.sndQueueInfo.PacketTooBigCount = 0 + mtu := e.sndQueueInfo.SndMTU + e.sndQueueInfo.sndQueueMu.Unlock() + + e.snd.updateMaxPayloadSize(mtu, count) + } + + if n¬ifyReset != 0 || n¬ifyAbort != 0 { + return &tcpip.ErrConnectionAborted{} + } + + if n¬ifyResetByPeer != 0 { + return &tcpip.ErrConnectionReset{} + } + + if n¬ifyClose != 0 && e.closed { + switch e.EndpointState() { + case StateEstablished: + // Perform full shutdown if the endpoint is + // still established. This can occur when + // notifyClose was asserted just before + // becoming established. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + case StateFinWait2: + // The socket has been closed and we are in + // FIN_WAIT2 so start the FIN_WAIT2 timer. + if *closeTimer == nil { + *closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + } + } + } + + if n¬ifyKeepaliveChanged != 0 { + // The timer could fire in background when the endpoint + // is drained. That's OK. See above. + e.resetKeepaliveTimer(true) + } + + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + if err := e.handleSegmentsLocked(false /* fastPath */); err != nil { + return err + } + } + if !e.EndpointState().closed() { + // Only block the worker if the endpoint + // is not in closed state or error state. + close(e.drainDone) + e.mu.Unlock() + <-e.undrain + e.mu.Lock() + } + } + + // N.B. notifyTickleWorker may be set, but there is no action + // to take in this case. + case &e.snd.reorderWaker: + return e.snd.rc.reorderTimerExpired() + default: + panic("unknown waker") // Shouldn't happen. + } + return nil +} + // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. @@ -1403,139 +1492,16 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.mu.Lock() } - // Set up the functions that will be called when the main protocol loop - // wakes up. - funcs := []struct { - w *sleep.Waker - f func() tcpip.Error - }{ - { - w: &e.sndQueueInfo.sndWaker, - f: func() tcpip.Error { - e.sendData(nil /* next */) - return nil - }, - }, - { - w: &closeWaker, - f: func() tcpip.Error { - // This means the socket is being closed due - // to the TCP-FIN-WAIT2 timeout was hit. Just - // mark the socket as closed. - e.transitionToStateCloseLocked() - e.workerCleanup = true - return nil - }, - }, - { - w: &e.snd.resendWaker, - f: func() tcpip.Error { - if !e.snd.retransmitTimerExpired() { - e.stack.Stats().TCP.EstablishedTimedout.Increment() - return &tcpip.ErrTimeout{} - } - return nil - }, - }, - { - w: &e.snd.probeWaker, - f: e.snd.probeTimerExpired, - }, - { - w: &e.newSegmentWaker, - f: func() tcpip.Error { - return e.handleSegmentsLocked(false /* fastPath */) - }, - }, - { - w: &e.keepalive.waker, - f: e.keepaliveTimerExpired, - }, - { - w: &e.notificationWaker, - f: func() tcpip.Error { - n := e.fetchNotifications() - if n¬ifyNonZeroReceiveWindow != 0 { - e.rcv.nonZeroWindow() - } - - if n¬ifyMTUChanged != 0 { - e.sndQueueInfo.sndQueueMu.Lock() - count := e.sndQueueInfo.PacketTooBigCount - e.sndQueueInfo.PacketTooBigCount = 0 - mtu := e.sndQueueInfo.SndMTU - e.sndQueueInfo.sndQueueMu.Unlock() - - e.snd.updateMaxPayloadSize(mtu, count) - } - - if n¬ifyReset != 0 || n¬ifyAbort != 0 { - return &tcpip.ErrConnectionAborted{} - } - - if n¬ifyResetByPeer != 0 { - return &tcpip.ErrConnectionReset{} - } - - if n¬ifyClose != 0 && e.closed { - switch e.EndpointState() { - case StateEstablished: - // Perform full shutdown if the endpoint is still - // established. This can occur when notifyClose - // was asserted just before becoming established. - e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) - case StateFinWait2: - // The socket has been closed and we are in FIN_WAIT2 - // so start the FIN_WAIT2 timer. - if closeTimer == nil { - closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) - } - } - } - - if n¬ifyKeepaliveChanged != 0 { - // The timer could fire in background - // when the endpoint is drained. That's - // OK. See above. - e.resetKeepaliveTimer(true) - } - - if n¬ifyDrain != 0 { - for !e.segmentQueue.empty() { - if err := e.handleSegmentsLocked(false /* fastPath */); err != nil { - return err - } - } - if !e.EndpointState().closed() { - // Only block the worker if the endpoint - // is not in closed state or error state. - close(e.drainDone) - e.mu.Unlock() // +checklocksforce - <-e.undrain - e.mu.Lock() - } - } - - if n¬ifyTickleWorker != 0 { - // Just a tickle notification. No need to do - // anything. - return nil - } - - return nil - }, - }, - { - w: &e.snd.reorderWaker, - f: e.snd.rc.reorderTimerExpired, - }, - } - - // Initialize the sleeper based on the wakers in funcs. + // Add all wakers. var s sleep.Sleeper - for i := range funcs { - s.AddWaker(funcs[i].w, i) - } + s.AddWaker(&e.sndQueueInfo.sndWaker) + s.AddWaker(&e.newSegmentWaker) + s.AddWaker(&e.snd.resendWaker) + s.AddWaker(&e.snd.probeWaker) + s.AddWaker(&closeWaker) + s.AddWaker(&e.keepalive.waker) + s.AddWaker(&e.notificationWaker) + s.AddWaker(&e.snd.reorderWaker) // Notify the caller that the waker initialization is complete and the // endpoint is ready. @@ -1581,7 +1547,7 @@ loop: } e.mu.Unlock() - v, _ := s.Fetch(true /* block */) + w := s.Fetch(true /* block */) e.mu.Lock() // We need to double check here because the notification may be @@ -1601,7 +1567,7 @@ loop: case StateClose: break loop default: - if err := funcs[v].f(); err != nil { + if err := e.handleWakeup(w, &closeWaker, &closeTimer); err != nil { cleanupOnError(err) e.protocolMainLoopDone(closeTimer) return @@ -1714,26 +1680,22 @@ func (e *endpoint) doTimeWait() (twReuse func()) { timeWaitDuration = time.Duration(tcpTW) } - const newSegment = 1 - const notification = 2 - const timeWaitDone = 3 - var s sleep.Sleeper defer s.Done() - s.AddWaker(&e.newSegmentWaker, newSegment) - s.AddWaker(&e.notificationWaker, notification) + s.AddWaker(&e.newSegmentWaker) + s.AddWaker(&e.notificationWaker) var timeWaitWaker sleep.Waker - s.AddWaker(&timeWaitWaker, timeWaitDone) + s.AddWaker(&timeWaitWaker) timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert) defer timeWaitTimer.Stop() for { e.mu.Unlock() - v, _ := s.Fetch(true /* block */) + w := s.Fetch(true /* block */) e.mu.Lock() - switch v { - case newSegment: + switch w { + case &e.newSegmentWaker: extendTimeWait, reuseTW := e.handleTimeWaitSegments() if reuseTW != nil { return reuseTW @@ -1741,7 +1703,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { if extendTimeWait { timeWaitTimer.Reset(timeWaitDuration) } - case notification: + case &e.notificationWaker: n := e.fetchNotifications() if n¬ifyAbort != 0 { return nil @@ -1759,7 +1721,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { e.mu.Lock() return nil } - case timeWaitDone: + case &timeWaitWaker: return nil } } diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 7d110516b..2e93d2664 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -94,9 +94,10 @@ func (p *processor) start(wg *sync.WaitGroup) { defer p.sleeper.Done() for { - if id, _ := p.sleeper.Fetch(true); id == closeWaker { + if w := p.sleeper.Fetch(true); w == &p.closeWaker { break } + // If not the closeWaker, it must be &p.newEndpointWaker. for { ep := p.epQ.dequeue() if ep == nil { @@ -154,8 +155,8 @@ func (d *dispatcher) init(rng *rand.Rand, nProcessors int) { d.seed = rng.Uint32() for i := range d.processors { p := &d.processors[i] - p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) - p.sleeper.AddWaker(&p.closeWaker, closeWaker) + p.sleeper.AddWaker(&p.newEndpointWaker) + p.sleeper.AddWaker(&p.closeWaker) d.wg.Add(1) // NB: sleeper-waker registration must happen synchronously to avoid races // with `close`. It's possible to pull all this logic into `start`, but |