summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go37
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go92
2 files changed, 114 insertions, 15 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e2bd57ebf..7acc7e7b0 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -398,25 +398,32 @@ func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Unlock()
defer e.pendingAccepted.Done()
- e.acceptMu.Lock()
- for {
- if e.accepted == (accepted{}) {
- n.notifyProtocolGoroutine(notifyReset)
- break
- }
- if e.accepted.endpoints.Len() == e.accepted.cap {
- e.acceptCond.Wait()
- continue
- }
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ if e.accepted == (accepted{}) {
+ return false
+ }
+ if e.accepted.endpoints.Len() == e.accepted.cap {
+ e.acceptCond.Wait()
+ continue
+ }
- e.accepted.endpoints.PushBack(n)
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
+ e.accepted.endpoints.PushBack(n)
+ if !withSynCookie {
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ }
+ return true
}
+ }()
+ if delivered {
e.waiterQueue.Notify(waiter.ReadableEvents)
- break
+ } else {
+ n.notifyProtocolGoroutine(notifyReset)
}
- e.acceptMu.Unlock()
}
// propagateInheritableOptionsLocked propagates any options set on the listening
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 5605a4390..cda3ecbcd 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1291,6 +1291,98 @@ func TestListenShutdown(t *testing.T) {
))
}
+var _ waiter.EntryCallback = (callback)(nil)
+
+type callback func(*waiter.Entry, waiter.EventMask)
+
+func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) {
+ cb(entry, mask)
+}
+
+func TestListenerReadinessOnEvent(t *testing.T) {
+ s := stack.New(stack.Options{
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ {
+ ep := loopback.New()
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+ const id = 1
+ if err := s.CreateNIC(id, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
+ }
+ if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
+ t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: header.IPv4EmptySubnet, NIC: id},
+ })
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil {
+ t.Fatalf("Bind(%s): %s", context.StackAddr, err)
+ }
+ const backlog = 1
+ if err := ep.Listen(backlog); err != nil {
+ t.Fatalf("Listen(%d): %s", backlog, err)
+ }
+
+ address, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress(): %s", err)
+ }
+
+ conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
+ }
+ defer conn.Close()
+
+ events := make(chan waiter.EventMask)
+ // Scope `entry` to allow a binding of the same name below.
+ {
+ entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) {
+ events <- ep.Readiness(mask)
+ })}
+ wq.EventRegister(&entry, waiter.EventIn)
+ defer wq.EventUnregister(&entry)
+ }
+
+ entry, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&entry, waiter.EventOut)
+ defer wq.EventUnregister(&entry)
+
+ switch err := conn.Connect(address).(type) {
+ case *tcpip.ErrConnectStarted:
+ default:
+ t.Fatalf("Connect(%#v): %v", address, err)
+ }
+
+ // Read at least one event.
+ got := <-events
+ for {
+ select {
+ case event := <-events:
+ got |= event
+ continue
+ case <-ch:
+ if want := waiter.ReadableEvents; got != want {
+ t.Errorf("observed events = %b, want %b", got, want)
+ }
+ }
+ break
+ }
+}
+
// TestListenCloseWhileConnect tests for the listening endpoint to
// drain the accept-queue when closed. This should reset all of the
// pending connections that are waiting to be accepted.