diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 37 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 92 |
2 files changed, 114 insertions, 15 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e2bd57ebf..7acc7e7b0 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -398,25 +398,32 @@ func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Unlock() defer e.pendingAccepted.Done() - e.acceptMu.Lock() - for { - if e.accepted == (accepted{}) { - n.notifyProtocolGoroutine(notifyReset) - break - } - if e.accepted.endpoints.Len() == e.accepted.cap { - e.acceptCond.Wait() - continue - } + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + if e.accepted == (accepted{}) { + return false + } + if e.accepted.endpoints.Len() == e.accepted.cap { + e.acceptCond.Wait() + continue + } - e.accepted.endpoints.PushBack(n) - if !withSynCookie { - atomic.AddInt32(&e.synRcvdCount, -1) + e.accepted.endpoints.PushBack(n) + if !withSynCookie { + atomic.AddInt32(&e.synRcvdCount, -1) + } + return true } + }() + if delivered { e.waiterQueue.Notify(waiter.ReadableEvents) - break + } else { + n.notifyProtocolGoroutine(notifyReset) } - e.acceptMu.Unlock() } // propagateInheritableOptionsLocked propagates any options set on the listening diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 5605a4390..cda3ecbcd 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1291,6 +1291,98 @@ func TestListenShutdown(t *testing.T) { )) } +var _ waiter.EntryCallback = (callback)(nil) + +type callback func(*waiter.Entry, waiter.EventMask) + +func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { + cb(entry, mask) +} + +func TestListenerReadinessOnEvent(t *testing.T) { + s := stack.New(stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + { + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + const id = 1 + if err := s.CreateNIC(id, ep); err != nil { + t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) + } + if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { + t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: header.IPv4EmptySubnet, NIC: id}, + }) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { + t.Fatalf("Bind(%s): %s", context.StackAddr, err) + } + const backlog = 1 + if err := ep.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s", backlog, err) + } + + address, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress(): %s", err) + } + + conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer conn.Close() + + events := make(chan waiter.EventMask) + // Scope `entry` to allow a binding of the same name below. + { + entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) { + events <- ep.Readiness(mask) + })} + wq.EventRegister(&entry, waiter.EventIn) + defer wq.EventUnregister(&entry) + } + + entry, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&entry, waiter.EventOut) + defer wq.EventUnregister(&entry) + + switch err := conn.Connect(address).(type) { + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect(%#v): %v", address, err) + } + + // Read at least one event. + got := <-events + for { + select { + case event := <-events: + got |= event + continue + case <-ch: + if want := waiter.ReadableEvents; got != want { + t.Errorf("observed events = %b, want %b", got, want) + } + } + break + } +} + // TestListenCloseWhileConnect tests for the listening endpoint to // drain the accept-queue when closed. This should reset all of the // pending connections that are waiting to be accepted. |