summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sleep/sleep_test.go182
-rw-r--r--pkg/sleep/sleep_unsafe.go108
-rw-r--r--pkg/syncevent/waiter_test.go22
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go11
-rw-r--r--pkg/tcpip/transport/tcp/accept.go13
-rw-r--r--pkg/tcpip/transport/tcp/connect.go286
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go7
7 files changed, 297 insertions, 332 deletions
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
index 1dd11707d..b27feb99b 100644
--- a/pkg/sleep/sleep_test.go
+++ b/pkg/sleep/sleep_test.go
@@ -67,7 +67,7 @@ func AssertedWakerAfterTwoAsserts(t *testing.T) {
func NotAssertedWakerWithSleeper(t *testing.T) {
var w Waker
var s Sleeper
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
if w.IsAsserted() {
t.Fatalf("Non-asserted waker is reported as asserted")
}
@@ -83,7 +83,7 @@ func NotAssertedWakerWithSleeper(t *testing.T) {
func NotAssertedWakerAfterWake(t *testing.T) {
var w Waker
var s Sleeper
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
w.Assert()
s.Fetch(true)
if w.IsAsserted() {
@@ -101,10 +101,10 @@ func AssertedWakerBeforeAdd(t *testing.T) {
var w Waker
var s Sleeper
w.Assert()
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
- if _, ok := s.Fetch(false); !ok {
- t.Fatalf("Fetch failed even though asserted waker was added")
+ if s.Fetch(false) != &w {
+ t.Fatalf("Fetch did not match waker")
}
}
@@ -128,7 +128,7 @@ func ClearedWaker(t *testing.T) {
func ClearedWakerWithSleeper(t *testing.T) {
var w Waker
var s Sleeper
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
w.Clear()
if w.IsAsserted() {
t.Fatalf("Cleared waker is reported as asserted")
@@ -145,7 +145,7 @@ func ClearedWakerWithSleeper(t *testing.T) {
func ClearedWakerAssertedWithSleeper(t *testing.T) {
var w Waker
var s Sleeper
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
w.Assert()
w.Clear()
if w.IsAsserted() {
@@ -163,18 +163,15 @@ func TestBlock(t *testing.T) {
var w Waker
var s Sleeper
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
// Assert waker after one second.
before := time.Now()
- go func() {
- time.Sleep(1 * time.Second)
- w.Assert()
- }()
+ time.AfterFunc(time.Second, w.Assert)
// Fetch the result and make sure it took at least 500ms.
- if _, ok := s.Fetch(true); !ok {
- t.Fatalf("Fetch failed unexpectedly")
+ if s.Fetch(true) != &w {
+ t.Fatalf("Fetch did not match waker")
}
if d := time.Now().Sub(before); d < 500*time.Millisecond {
t.Fatalf("Duration was too short: %v", d)
@@ -182,8 +179,8 @@ func TestBlock(t *testing.T) {
// Check that already-asserted waker completes inline.
w.Assert()
- if _, ok := s.Fetch(true); !ok {
- t.Fatalf("Fetch failed unexpectedly")
+ if s.Fetch(true) != &w {
+ t.Fatalf("Fetch did not match waker")
}
// Check that fetch sleeps if waker had been asserted but was reset
@@ -191,12 +188,10 @@ func TestBlock(t *testing.T) {
w.Assert()
w.Clear()
before = time.Now()
- go func() {
- time.Sleep(1 * time.Second)
- w.Assert()
- }()
- if _, ok := s.Fetch(true); !ok {
- t.Fatalf("Fetch failed unexpectedly")
+ time.AfterFunc(time.Second, w.Assert)
+
+ if s.Fetch(true) != &w {
+ t.Fatalf("Fetch did not match waker")
}
if d := time.Now().Sub(before); d < 500*time.Millisecond {
t.Fatalf("Duration was too short: %v", d)
@@ -209,30 +204,30 @@ func TestNonBlock(t *testing.T) {
var s Sleeper
// Don't block when there's no waker.
- if _, ok := s.Fetch(false); ok {
+ if s.Fetch(false) != nil {
t.Fatalf("Fetch succeeded when there is no waker")
}
// Don't block when waker isn't asserted.
- s.AddWaker(&w, 0)
- if _, ok := s.Fetch(false); ok {
+ s.AddWaker(&w)
+ if s.Fetch(false) != nil {
t.Fatalf("Fetch succeeded when waker was not asserted")
}
// Don't block when waker was asserted, but isn't anymore.
w.Assert()
w.Clear()
- if _, ok := s.Fetch(false); ok {
+ if s.Fetch(false) != nil {
t.Fatalf("Fetch succeeded when waker was not asserted anymore")
}
// Don't block when waker was consumed by previous Fetch().
w.Assert()
- if _, ok := s.Fetch(false); !ok {
+ if s.Fetch(false) != &w {
t.Fatalf("Fetch failed even though waker was asserted")
}
- if _, ok := s.Fetch(false); ok {
+ if s.Fetch(false) != nil {
t.Fatalf("Fetch succeeded when waker had been consumed")
}
}
@@ -244,29 +239,30 @@ func TestMultiple(t *testing.T) {
w1 := Waker{}
w2 := Waker{}
- s.AddWaker(&w1, 0)
- s.AddWaker(&w2, 1)
+ s.AddWaker(&w1)
+ s.AddWaker(&w2)
w1.Assert()
w2.Assert()
- v, ok := s.Fetch(false)
- if !ok {
+ v := s.Fetch(false)
+ if v == nil {
t.Fatalf("Fetch failed when there are asserted wakers")
}
-
- if v != 0 && v != 1 {
- t.Fatalf("Unexpected waker id: %v", v)
+ if v != &w1 && v != &w2 {
+ t.Fatalf("Unexpected waker: %v", v)
}
- want := 1 - v
- v, ok = s.Fetch(false)
- if !ok {
+ want := &w1
+ if v == want {
+ want = &w2 // Other waiter.
+ }
+ v = s.Fetch(false)
+ if v == nil {
t.Fatalf("Fetch failed when there is an asserted waker")
}
-
if v != want {
- t.Fatalf("Unexpected waker id, got %v, want %v", v, want)
+ t.Fatalf("Unexpected waker, got %v, want %v", v, want)
}
}
@@ -281,7 +277,7 @@ func TestDoneFunction(t *testing.T) {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
- s.AddWaker(&w[j], j)
+ s.AddWaker(&w[j])
}
s.Done()
}
@@ -293,7 +289,7 @@ func TestDoneFunction(t *testing.T) {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
- s.AddWaker(&w[j], j)
+ s.AddWaker(&w[j])
}
w[i].Assert()
s.Done()
@@ -307,7 +303,7 @@ func TestDoneFunction(t *testing.T) {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
- s.AddWaker(&w[j], j)
+ s.AddWaker(&w[j])
}
w[i].Assert()
w[i].Clear()
@@ -322,7 +318,7 @@ func TestDoneFunction(t *testing.T) {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
- s.AddWaker(&w[j], j)
+ s.AddWaker(&w[j])
}
// Pick the number of asserted elements, then assert
@@ -342,14 +338,14 @@ func TestRace(t *testing.T) {
const wakers = 100
const wakeRequests = 10000
- counts := make([]int, wakers)
- w := make([]Waker, wakers)
+ counts := make(map[*Waker]int, wakers)
s := Sleeper{}
// Associate each waker and start goroutines that will assert them.
- for i := range w {
- s.AddWaker(&w[i], i)
- go func(w *Waker) {
+ for i := 0; i < wakers; i++ {
+ var w Waker
+ s.AddWaker(&w)
+ go func() {
n := 0
for n < wakeRequests {
if !w.IsAsserted() {
@@ -359,19 +355,22 @@ func TestRace(t *testing.T) {
runtime.Gosched()
}
}
- }(&w[i])
+ }()
}
// Wait for all wake up notifications from all wakers.
for i := 0; i < wakers*wakeRequests; i++ {
- v, _ := s.Fetch(true)
+ v := s.Fetch(true)
counts[v]++
}
// Check that we got the right number for each.
- for i, v := range counts {
- if v != wakeRequests {
- t.Errorf("Waker %v only got %v wakes", i, v)
+ if got := len(counts); got != wakers {
+ t.Errorf("Got %d wakers, wanted %d", got, wakers)
+ }
+ for _, count := range counts {
+ if count != wakeRequests {
+ t.Errorf("Waker only got %d wakes, wanted %d", count, wakeRequests)
}
}
}
@@ -384,7 +383,7 @@ func TestRaceInOrder(t *testing.T) {
// Associate each waker and start goroutines that will assert them.
for i := range w {
- s.AddWaker(&w[i], i)
+ s.AddWaker(&w[i])
}
go func() {
for i := range w {
@@ -393,10 +392,10 @@ func TestRaceInOrder(t *testing.T) {
}()
// Wait for all wake up notifications from all wakers.
- for want := range w {
- got, _ := s.Fetch(true)
- if got != want {
- t.Fatalf("got %d want %d", got, want)
+ for i := range w {
+ got := s.Fetch(true)
+ if want := &w[i]; got != want {
+ t.Fatalf("got %v want %v", got, want)
}
}
}
@@ -408,7 +407,7 @@ func BenchmarkSleeperMultiSelect(b *testing.B) {
s := Sleeper{}
w := make([]Waker, count)
for i := range w {
- s.AddWaker(&w[i], i)
+ s.AddWaker(&w[i])
}
b.ResetTimer()
@@ -444,7 +443,7 @@ func BenchmarkGoMultiSelect(b *testing.B) {
func BenchmarkSleeperSingleSelect(b *testing.B) {
s := Sleeper{}
w := Waker{}
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -494,16 +493,24 @@ func BenchmarkGoAssertNonWaiting(b *testing.B) {
// a new goroutine doesn't run immediately (i.e., the creator of a new goroutine
// is allowed to go to sleep before the new goroutine has a chance to run).
func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) {
- s := Sleeper{}
- w := Waker{}
- s.AddWaker(&w, 0)
- for i := 0; i < b.N; i++ {
- go func() {
+ var (
+ s Sleeper
+ w Waker
+ ns Sleeper
+ nw Waker
+ )
+ ns.AddWaker(&nw)
+ s.AddWaker(&w)
+ go func() {
+ for i := 0; i < b.N; i++ {
+ ns.Fetch(true)
w.Assert()
- }()
+ }
+ }()
+ for i := 0; i < b.N; i++ {
+ nw.Assert()
s.Fetch(true)
}
-
}
// BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one
@@ -511,11 +518,13 @@ func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) {
// goroutine doesn't run immediately (i.e., the creator of a new goroutine is
// allowed to go to sleep before the new goroutine has a chance to run).
func BenchmarkGoWaitOnSingleSelect(b *testing.B) {
- ch := make(chan struct{}, 1)
- for i := 0; i < b.N; i++ {
- go func() {
+ ch := make(chan struct{})
+ go func() {
+ for i := 0; i < b.N; i++ {
ch <- struct{}{}
- }()
+ }
+ }()
+ for i := 0; i < b.N; i++ {
<-ch
}
}
@@ -526,17 +535,26 @@ func BenchmarkGoWaitOnSingleSelect(b *testing.B) {
// allowed to go to sleep before the new goroutine has a chance to run).
func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) {
const count = 4
- s := Sleeper{}
+ var (
+ s Sleeper
+ ns Sleeper
+ nw Waker
+ )
+ ns.AddWaker(&nw)
w := make([]Waker, count)
for i := range w {
- s.AddWaker(&w[i], i)
+ s.AddWaker(&w[i])
}
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- go func() {
+ go func() {
+ for i := 0; i < b.N; i++ {
+ ns.Fetch(true)
w[count-1].Assert()
- }()
+ }
+ }()
+ for i := 0; i < b.N; i++ {
+ nw.Assert()
s.Fetch(true)
}
}
@@ -549,14 +567,16 @@ func BenchmarkGoWaitOnMultiSelect(b *testing.B) {
const count = 4
ch := make([]chan struct{}, count)
for i := range ch {
- ch[i] = make(chan struct{}, 1)
+ ch[i] = make(chan struct{})
}
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- go func() {
+ go func() {
+ for i := 0; i < b.N; i++ {
ch[count-1] <- struct{}{}
- }()
+ }
+ }()
+ for i := 0; i < b.N; i++ {
select {
case <-ch[0]:
case <-ch[1]:
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index c44206b1e..86c7cc983 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -37,15 +37,15 @@
//
// // One time set-up.
// s := sleep.Sleeper{}
-// s.AddWaker(&w1, constant1)
-// s.AddWaker(&w2, constant2)
+// s.AddWaker(&w1)
+// s.AddWaker(&w2)
//
// // Called repeatedly.
// for {
-// switch id, _ := s.Fetch(true); id {
-// case constant1:
+// switch s.Fetch(true) {
+// case &w1:
// // Do work triggered by w1 being asserted.
-// case constant2:
+// case &w2:
// // Do work triggered by w2 being asserted.
// }
// }
@@ -119,13 +119,18 @@ type Sleeper struct {
waitingG uintptr
}
-// AddWaker associates the given waker to the sleeper. id is the value to be
-// returned when the sleeper is woken by the given waker.
-func (s *Sleeper) AddWaker(w *Waker, id int) {
+// AddWaker associates the given waker to the sleeper.
+func (s *Sleeper) AddWaker(w *Waker) {
+ if w.allWakersNext != nil {
+ panic("waker has non-nil allWakersNext; owned by another sleeper?")
+ }
+ if w.next != nil {
+ panic("waker has non-nil next; queued in another sleeper?")
+ }
+
// Add the waker to the list of all wakers.
w.allWakersNext = s.allWakers
s.allWakers = w
- w.id = id
// Try to associate the waker with the sleeper. If it's already
// asserted, we simply enqueue it in the "ready" list.
@@ -213,28 +218,26 @@ func commitSleep(g uintptr, waitingG unsafe.Pointer) bool {
return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(waitingG), preparingG, g)
}
-// Fetch fetches the next wake-up notification. If a notification is immediately
-// available, it is returned right away. Otherwise, the behavior depends on the
-// value of 'block': if true, the current goroutine blocks until a notification
-// arrives, then returns it; if false, returns 'ok' as false.
-//
-// When 'ok' is true, the value of 'id' corresponds to the id associated with
-// the waker; when 'ok' is false, 'id' is undefined.
+// Fetch fetches the next wake-up notification. If a notification is
+// immediately available, the asserted waker is returned immediately.
+// Otherwise, the behavior depends on the value of 'block': if true, the
+// current goroutine blocks until a notification arrives and returns the
+// asserted waker; if false, nil will be returned.
//
// N.B. This method is *not* thread-safe. Only one goroutine at a time is
// allowed to call this method.
-func (s *Sleeper) Fetch(block bool) (id int, ok bool) {
+func (s *Sleeper) Fetch(block bool) *Waker {
for {
w := s.nextWaker(block)
if w == nil {
- return -1, false
+ return nil
}
// Reassociate the waker with the sleeper. If the waker was
// still asserted we can return it, otherwise try the next one.
old := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(s)))
if old == &assertedSleeper {
- return w.id, true
+ return w
}
}
}
@@ -243,51 +246,34 @@ func (s *Sleeper) Fetch(block bool) (id int, ok bool) {
// removes the association with all wakers so that they can be safely reused
// by another sleeper after Done() returns.
func (s *Sleeper) Done() {
- // Remove all associations that we can, and build a list of the ones
- // we could not. An association can be removed right away from waker w
- // if w.s has a pointer to the sleeper, that is, the waker is not
- // asserted yet. By atomically switching w.s to nil, we guarantee that
- // subsequent calls to Assert() on the waker will not result in it being
- // queued to this sleeper.
- var pending *Waker
- w := s.allWakers
- for w != nil {
- next := w.allWakersNext
- for {
- t := atomic.LoadPointer(&w.s)
- if t != usleeper(s) {
- w.allWakersNext = pending
- pending = w
- break
- }
-
- if atomic.CompareAndSwapPointer(&w.s, t, nil) {
- break
- }
+ // Remove all associations that we can, and build a list of the ones we
+ // could not. An association can be removed right away from waker w if
+ // w.s has a pointer to the sleeper, that is, the waker is not asserted
+ // yet. By atomically switching w.s to nil, we guarantee that
+ // subsequent calls to Assert() on the waker will not result in it
+ // being queued.
+ for w := s.allWakers; w != nil; w = s.allWakers {
+ next := w.allWakersNext // Before zapping.
+ if atomic.CompareAndSwapPointer(&w.s, usleeper(s), nil) {
+ w.allWakersNext = nil
+ w.next = nil
+ s.allWakers = next // Move ahead.
+ continue
}
- w = next
- }
- // The associations that we could not remove are either asserted, or in
- // the process of being asserted, or have been asserted and cleared
- // before being pulled from the sleeper lists. We must wait for them all
- // to make it to the sleeper lists, so that we know that the wakers
- // won't do any more work towards waking this sleeper up.
- for pending != nil {
- pulled := s.nextWaker(true)
-
- // Remove the waker we just pulled from the list of associated
- // wakers.
- prev := &pending
- for w := *prev; w != nil; w = *prev {
- if pulled == w {
- *prev = w.allWakersNext
- break
+ // Dequeue exactly one waiter from the list, it may not be
+ // this one but we know this one is in the process. We must
+ // leave it in the asserted state but drop it from our lists.
+ if w := s.nextWaker(true); w != nil {
+ prev := &s.allWakers
+ for *prev != w {
+ prev = &((*prev).allWakersNext)
}
- prev = &w.allWakersNext
+ *prev = (*prev).allWakersNext
+ w.allWakersNext = nil
+ w.next = nil
}
}
- s.allWakers = nil
}
// enqueueAssertedWaker enqueues an asserted waker to the "ready" circular list
@@ -349,10 +335,6 @@ type Waker struct {
// allWakersNext is used to form a linked list of all wakers associated
// to a given sleeper.
allWakersNext *Waker
-
- // id is the value to be returned to sleepers when they wake up due to
- // this waker being asserted.
- id int
}
// Assert moves the waker to an asserted state, if it isn't asserted yet. When
diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go
index 3c8cbcdd8..cfa0972c0 100644
--- a/pkg/syncevent/waiter_test.go
+++ b/pkg/syncevent/waiter_test.go
@@ -105,7 +105,7 @@ func BenchmarkWaiterNotifyRedundant(b *testing.B) {
func BenchmarkSleeperNotifyRedundant(b *testing.B) {
var s sleep.Sleeper
var w sleep.Waker
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
w.Assert()
b.ResetTimer()
@@ -146,7 +146,7 @@ func BenchmarkWaiterNotifyWaitAck(b *testing.B) {
func BenchmarkSleeperNotifyWaitAck(b *testing.B) {
var s sleep.Sleeper
var w sleep.Waker
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -197,7 +197,7 @@ func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := wakerPool.Get().(*sleep.Waker)
- s.AddWaker(w, 0)
+ s.AddWaker(w)
w.Assert()
s.Fetch(true)
s.Done()
@@ -237,7 +237,7 @@ func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
s := sleeperPool.Get().(*sleep.Sleeper)
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
w.Assert()
s.Fetch(true)
s.Done()
@@ -266,14 +266,14 @@ func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) {
var s sleep.Sleeper
var ws [3]sleep.Waker
for i := range ws {
- s.AddWaker(&ws[i], i)
+ s.AddWaker(&ws[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ws[0].Assert()
- if id, _ := s.Fetch(true); id != 0 {
- b.Fatalf("Fetch: got %d, wanted 0", id)
+ if v := s.Fetch(true); v != &ws[0] {
+ b.Fatalf("Fetch: got %v, wanted %v", v, &ws[0])
}
}
}
@@ -325,7 +325,7 @@ func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) {
func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) {
var s sleep.Sleeper
var w sleep.Waker
- s.AddWaker(&w, 0)
+ s.AddWaker(&w)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -374,7 +374,7 @@ func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
var s sleep.Sleeper
var ws [3]sleep.Waker
for i := range ws {
- s.AddWaker(&ws[i], i)
+ s.AddWaker(&ws[i])
}
b.ResetTimer()
@@ -382,8 +382,8 @@ func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
go func() {
ws[0].Assert()
}()
- if id, _ := s.Fetch(true); id != 0 {
- b.Fatalf("Fetch: got %d, expected 0", id)
+ if v := s.Fetch(true); v != &ws[0] {
+ b.Fatalf("Fetch: got %v, expected %v", v, &ws[0])
}
}
}
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index b41e3e2fa..c15cbf81b 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -73,20 +73,19 @@ func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint {
}
func (q *queueDispatcher) dispatchLoop() {
- const newPacketWakerID = 1
- const closeWakerID = 2
s := sleep.Sleeper{}
- s.AddWaker(&q.newPacketWaker, newPacketWakerID)
- s.AddWaker(&q.closeWaker, closeWakerID)
+ s.AddWaker(&q.newPacketWaker)
+ s.AddWaker(&q.closeWaker)
defer s.Done()
const batchSize = 32
var batch stack.PacketBufferList
for {
- id, ok := s.Fetch(true)
- if ok && id == closeWakerID {
+ w := s.Fetch(true)
+ if w == &q.closeWaker {
return
}
+ // Must otherwise be the newPacketWaker.
for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() {
batch.PushBack(pkt)
if batch.Len() < batchSize && !q.q.empty() {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index caf14b0dc..d0f68b72c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -762,14 +762,15 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
}()
var s sleep.Sleeper
- s.AddWaker(&e.notificationWaker, wakerForNotification)
- s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+ s.AddWaker(&e.notificationWaker)
+ s.AddWaker(&e.newSegmentWaker)
+ defer s.Done()
for {
e.mu.Unlock()
- index, _ := s.Fetch(true)
+ w := s.Fetch(true)
e.mu.Lock()
- switch index {
- case wakerForNotification:
+ switch w {
+ case &e.notificationWaker:
n := e.fetchNotifications()
if n&notifyClose != 0 {
return
@@ -788,7 +789,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
}
- case wakerForNewSegment:
+ case &e.newSegmentWaker:
// Process at most maxSegmentsPerWake segments.
mayRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 80cd07218..12df7a7b4 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -51,13 +51,6 @@ const (
handshakeCompleted
)
-// The following are used to set up sleepers.
-const (
- wakerForNotification = iota
- wakerForNewSegment
- wakerForResend
-)
-
const (
// Maximum space available for options.
maxOptionSize = 40
@@ -530,9 +523,9 @@ func (h *handshake) complete() tcpip.Error {
// Set up the wakers.
var s sleep.Sleeper
resendWaker := sleep.Waker{}
- s.AddWaker(&resendWaker, wakerForResend)
- s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
- s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ s.AddWaker(&resendWaker)
+ s.AddWaker(&h.ep.notificationWaker)
+ s.AddWaker(&h.ep.newSegmentWaker)
defer s.Done()
// Initialize the resend timer.
@@ -545,11 +538,10 @@ func (h *handshake) complete() tcpip.Error {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
// throughout handshake processing).
h.ep.mu.Unlock()
- index, _ := s.Fetch(true /* block */)
+ w := s.Fetch(true /* block */)
h.ep.mu.Lock()
- switch index {
-
- case wakerForResend:
+ switch w {
+ case &resendWaker:
if err := timer.reset(); err != nil {
return err
}
@@ -577,7 +569,7 @@ func (h *handshake) complete() tcpip.Error {
h.sampleRTTWithTSOnly = true
}
- case wakerForNotification:
+ case &h.ep.notificationWaker:
n := h.ep.fetchNotifications()
if (n&notifyClose)|(n&notifyAbort) != 0 {
return &tcpip.ErrAborted{}
@@ -611,7 +603,7 @@ func (h *handshake) complete() tcpip.Error {
// cleared because of a socket layer call.
return &tcpip.ErrConnectionAborted{}
}
- case wakerForNewSegment:
+ case &h.ep.newSegmentWaker:
if err := h.processSegments(); err != nil {
return err
}
@@ -1346,6 +1338,103 @@ func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
+// handleWakeup handles a wakeup event while connected.
+//
+// +checklocks:e.mu
+func (e *endpoint) handleWakeup(w, closeWaker *sleep.Waker, closeTimer *tcpip.Timer) tcpip.Error {
+ switch w {
+ case &e.sndQueueInfo.sndWaker:
+ e.sendData(nil /* next */)
+ case &e.newSegmentWaker:
+ return e.handleSegmentsLocked(false /* fastPath */)
+ case &e.snd.resendWaker:
+ if !e.snd.retransmitTimerExpired() {
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return &tcpip.ErrTimeout{}
+ }
+ case closeWaker:
+ // This means the socket is being closed due to the
+ // TCP-FIN-WAIT2 timeout was hit. Just mark the socket as
+ // closed.
+ e.transitionToStateCloseLocked()
+ e.workerCleanup = true
+ case &e.snd.probeWaker:
+ return e.snd.probeTimerExpired()
+ case &e.keepalive.waker:
+ return e.keepaliveTimerExpired()
+ case &e.notificationWaker:
+ n := e.fetchNotifications()
+ if n&notifyNonZeroReceiveWindow != 0 {
+ e.rcv.nonZeroWindow()
+ }
+
+ if n&notifyMTUChanged != 0 {
+ e.sndQueueInfo.sndQueueMu.Lock()
+ count := e.sndQueueInfo.PacketTooBigCount
+ e.sndQueueInfo.PacketTooBigCount = 0
+ mtu := e.sndQueueInfo.SndMTU
+ e.sndQueueInfo.sndQueueMu.Unlock()
+
+ e.snd.updateMaxPayloadSize(mtu, count)
+ }
+
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
+ return &tcpip.ErrConnectionAborted{}
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return &tcpip.ErrConnectionReset{}
+ }
+
+ if n&notifyClose != 0 && e.closed {
+ switch e.EndpointState() {
+ case StateEstablished:
+ // Perform full shutdown if the endpoint is
+ // still established. This can occur when
+ // notifyClose was asserted just before
+ // becoming established.
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ case StateFinWait2:
+ // The socket has been closed and we are in
+ // FIN_WAIT2 so start the FIN_WAIT2 timer.
+ if *closeTimer == nil {
+ *closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ }
+ }
+ }
+
+ if n&notifyKeepaliveChanged != 0 {
+ // The timer could fire in background when the endpoint
+ // is drained. That's OK. See above.
+ e.resetKeepaliveTimer(true)
+ }
+
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegmentsLocked(false /* fastPath */); err != nil {
+ return err
+ }
+ }
+ if !e.EndpointState().closed() {
+ // Only block the worker if the endpoint
+ // is not in closed state or error state.
+ close(e.drainDone)
+ e.mu.Unlock()
+ <-e.undrain
+ e.mu.Lock()
+ }
+ }
+
+ // N.B. notifyTickleWorker may be set, but there is no action
+ // to take in this case.
+ case &e.snd.reorderWaker:
+ return e.snd.rc.reorderTimerExpired()
+ default:
+ panic("unknown waker") // Shouldn't happen.
+ }
+ return nil
+}
+
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
@@ -1403,139 +1492,16 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.mu.Lock()
}
- // Set up the functions that will be called when the main protocol loop
- // wakes up.
- funcs := []struct {
- w *sleep.Waker
- f func() tcpip.Error
- }{
- {
- w: &e.sndQueueInfo.sndWaker,
- f: func() tcpip.Error {
- e.sendData(nil /* next */)
- return nil
- },
- },
- {
- w: &closeWaker,
- f: func() tcpip.Error {
- // This means the socket is being closed due
- // to the TCP-FIN-WAIT2 timeout was hit. Just
- // mark the socket as closed.
- e.transitionToStateCloseLocked()
- e.workerCleanup = true
- return nil
- },
- },
- {
- w: &e.snd.resendWaker,
- f: func() tcpip.Error {
- if !e.snd.retransmitTimerExpired() {
- e.stack.Stats().TCP.EstablishedTimedout.Increment()
- return &tcpip.ErrTimeout{}
- }
- return nil
- },
- },
- {
- w: &e.snd.probeWaker,
- f: e.snd.probeTimerExpired,
- },
- {
- w: &e.newSegmentWaker,
- f: func() tcpip.Error {
- return e.handleSegmentsLocked(false /* fastPath */)
- },
- },
- {
- w: &e.keepalive.waker,
- f: e.keepaliveTimerExpired,
- },
- {
- w: &e.notificationWaker,
- f: func() tcpip.Error {
- n := e.fetchNotifications()
- if n&notifyNonZeroReceiveWindow != 0 {
- e.rcv.nonZeroWindow()
- }
-
- if n&notifyMTUChanged != 0 {
- e.sndQueueInfo.sndQueueMu.Lock()
- count := e.sndQueueInfo.PacketTooBigCount
- e.sndQueueInfo.PacketTooBigCount = 0
- mtu := e.sndQueueInfo.SndMTU
- e.sndQueueInfo.sndQueueMu.Unlock()
-
- e.snd.updateMaxPayloadSize(mtu, count)
- }
-
- if n&notifyReset != 0 || n&notifyAbort != 0 {
- return &tcpip.ErrConnectionAborted{}
- }
-
- if n&notifyResetByPeer != 0 {
- return &tcpip.ErrConnectionReset{}
- }
-
- if n&notifyClose != 0 && e.closed {
- switch e.EndpointState() {
- case StateEstablished:
- // Perform full shutdown if the endpoint is still
- // established. This can occur when notifyClose
- // was asserted just before becoming established.
- e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
- case StateFinWait2:
- // The socket has been closed and we are in FIN_WAIT2
- // so start the FIN_WAIT2 timer.
- if closeTimer == nil {
- closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
- }
- }
- }
-
- if n&notifyKeepaliveChanged != 0 {
- // The timer could fire in background
- // when the endpoint is drained. That's
- // OK. See above.
- e.resetKeepaliveTimer(true)
- }
-
- if n&notifyDrain != 0 {
- for !e.segmentQueue.empty() {
- if err := e.handleSegmentsLocked(false /* fastPath */); err != nil {
- return err
- }
- }
- if !e.EndpointState().closed() {
- // Only block the worker if the endpoint
- // is not in closed state or error state.
- close(e.drainDone)
- e.mu.Unlock() // +checklocksforce
- <-e.undrain
- e.mu.Lock()
- }
- }
-
- if n&notifyTickleWorker != 0 {
- // Just a tickle notification. No need to do
- // anything.
- return nil
- }
-
- return nil
- },
- },
- {
- w: &e.snd.reorderWaker,
- f: e.snd.rc.reorderTimerExpired,
- },
- }
-
- // Initialize the sleeper based on the wakers in funcs.
+ // Add all wakers.
var s sleep.Sleeper
- for i := range funcs {
- s.AddWaker(funcs[i].w, i)
- }
+ s.AddWaker(&e.sndQueueInfo.sndWaker)
+ s.AddWaker(&e.newSegmentWaker)
+ s.AddWaker(&e.snd.resendWaker)
+ s.AddWaker(&e.snd.probeWaker)
+ s.AddWaker(&closeWaker)
+ s.AddWaker(&e.keepalive.waker)
+ s.AddWaker(&e.notificationWaker)
+ s.AddWaker(&e.snd.reorderWaker)
// Notify the caller that the waker initialization is complete and the
// endpoint is ready.
@@ -1581,7 +1547,7 @@ loop:
}
e.mu.Unlock()
- v, _ := s.Fetch(true /* block */)
+ w := s.Fetch(true /* block */)
e.mu.Lock()
// We need to double check here because the notification may be
@@ -1601,7 +1567,7 @@ loop:
case StateClose:
break loop
default:
- if err := funcs[v].f(); err != nil {
+ if err := e.handleWakeup(w, &closeWaker, &closeTimer); err != nil {
cleanupOnError(err)
e.protocolMainLoopDone(closeTimer)
return
@@ -1714,26 +1680,22 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
timeWaitDuration = time.Duration(tcpTW)
}
- const newSegment = 1
- const notification = 2
- const timeWaitDone = 3
-
var s sleep.Sleeper
defer s.Done()
- s.AddWaker(&e.newSegmentWaker, newSegment)
- s.AddWaker(&e.notificationWaker, notification)
+ s.AddWaker(&e.newSegmentWaker)
+ s.AddWaker(&e.notificationWaker)
var timeWaitWaker sleep.Waker
- s.AddWaker(&timeWaitWaker, timeWaitDone)
+ s.AddWaker(&timeWaitWaker)
timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
defer timeWaitTimer.Stop()
for {
e.mu.Unlock()
- v, _ := s.Fetch(true /* block */)
+ w := s.Fetch(true /* block */)
e.mu.Lock()
- switch v {
- case newSegment:
+ switch w {
+ case &e.newSegmentWaker:
extendTimeWait, reuseTW := e.handleTimeWaitSegments()
if reuseTW != nil {
return reuseTW
@@ -1741,7 +1703,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
if extendTimeWait {
timeWaitTimer.Reset(timeWaitDuration)
}
- case notification:
+ case &e.notificationWaker:
n := e.fetchNotifications()
if n&notifyAbort != 0 {
return nil
@@ -1759,7 +1721,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
e.mu.Lock()
return nil
}
- case timeWaitDone:
+ case &timeWaitWaker:
return nil
}
}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 7d110516b..2e93d2664 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -94,9 +94,10 @@ func (p *processor) start(wg *sync.WaitGroup) {
defer p.sleeper.Done()
for {
- if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ if w := p.sleeper.Fetch(true); w == &p.closeWaker {
break
}
+ // If not the closeWaker, it must be &p.newEndpointWaker.
for {
ep := p.epQ.dequeue()
if ep == nil {
@@ -154,8 +155,8 @@ func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
d.seed = rng.Uint32()
for i := range d.processors {
p := &d.processors[i]
- p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
- p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ p.sleeper.AddWaker(&p.newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker)
d.wg.Add(1)
// NB: sleeper-waker registration must happen synchronously to avoid races
// with `close`. It's possible to pull all this logic into `start`, but