diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2020-09-02 18:19:50 -0700 |
---|---|---|
committer | Andrei Vagin <avagin@gmail.com> | 2020-09-09 17:53:10 -0700 |
commit | 89185098fae9cfe2725e7645aeb2ebd57aa26320 (patch) | |
tree | 84f260a8497acc75d20c53d0c06e453a1a734083 | |
parent | f33077c875523d5f33edf8c395300e29e7c9e609 (diff) |
Fix Accept to not return error for sockets in accept queue.
Accept on gVisor will return an error if a socket in the accept queue was closed
before Accept() was called. Linux will return the new fd even if the returned
socket is already closed by the peer say due to a RST being sent by the peer.
This seems to be intentional in linux more details on the github issue.
Fixes #3780
PiperOrigin-RevId: 329828404
20 files changed, 163 insertions, 152 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 36c17d1ba..91790834b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -830,7 +830,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -839,7 +839,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Try to accept the connection again; if it fails, then wait until we // get a notification. for { - if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock { + if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock { return ep, wq, syserr.TranslateNetstackError(err) } @@ -852,15 +852,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -880,13 +883,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 1f7d17f5f..0f342e655 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -151,14 +151,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs // tcpip.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -176,13 +180,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { + if peerAddr != nil { // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index e3a75b519..aa4f3c04d 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { } // Accept accepts a new connection. -func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { +func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { e.Lock() defer e.Unlock() @@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { select { case ne := <-e.acceptedChan: + if peerAddr != nil { + ne.Lock() + c := ne.connected + ne.Unlock() + if c != nil { + addr, err := c.GetLocalAddress() + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + *peerAddr = addr + } + } return ne, nil default: diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 4751b2fd8..f8aacca13 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi } // Listen starts listening on the connection. -func (e *connectionlessEndpoint) Listen(int) *syserr.Error { +func (*connectionlessEndpoint) Listen(int) *syserr.Error { return syserr.ErrNotSupported } // Accept accepts a new connection. -func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) { +func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) { return nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 1200cf9bb..cbbdd000f 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -151,7 +151,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *syserr.Error) + // + // peerAddr if not nil will be populated with the address of the connected + // peer on a successful accept. + Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 0a7a26495..616530eb6 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -205,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -214,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -227,15 +227,18 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -252,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 65a285b8f..e25c7e84a 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -96,7 +96,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.socketOpsCommon.EventRegister(&e, waiter.EventIn) @@ -105,7 +105,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -118,15 +118,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -144,13 +147,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 68a954a10..4f551cd92 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { // Accept implements net.Conn.Accept. func (l *TCPListener) Accept() (net.Conn, error) { - n, wq, err := l.ep.Accept() + n, wq, err := l.ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. @@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { defer l.wq.EventUnregister(&waitEntry) for { - n, wq, err = l.ep.Accept() + n, wq, err = l.ep.Accept(nil) if err != tcpip.ErrWouldBlock { break diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 9e37cab18..3f58a15ea 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -188,7 +188,7 @@ func main() { defer wq.EventUnregister(&waitEntry) for { - n, wq, err := ep.Accept() + n, wq, err := ep.Accept(nil) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a1458c899..9292bfccb 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -180,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -631,7 +631,7 @@ func TestTransportForwarding(t *testing.T) { Data: req.ToVectorisedView(), })) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err != nil || aep == nil { t.Fatalf("Accept failed: %v, %v", aep, err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b113d8613..8ba615521 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -561,7 +561,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *waiter.Queue, *Error) + // + // If peerAddr is not nil then it is populated with the peer address of the + // returned endpoint. + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 346ca4bda..ad71ff3b6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -597,7 +597,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 81093e9ca..8bd4e5e37 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -192,13 +192,13 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes return ep.ReadPacket(addr, nil) } -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -210,25 +210,25 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes tcpip.ErrNotSupported. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrNotSupported } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { return tcpip.ErrNotSupported } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with // Accept, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -267,12 +267,12 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -371,7 +371,7 @@ func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { +func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrNotSupported } @@ -508,7 +508,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (*endpoint) State() uint32 { return 0 } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 71feeb748..fb03e6047 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -446,12 +446,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -482,12 +482,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 6074cc24e..80e9dd465 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -371,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -510,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() + var addr tcpip.FullAddress + nep, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -526,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) { } } + if addr.Addr != context.TestV6Addr { + t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) + } + // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) } } @@ -610,12 +604,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3f18efeef..4cf966b65 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2453,7 +2453,9 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +// +// addr if not-nil will contain the peer address of the returned endpoint. +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2475,6 +2477,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } + if peerAddr != nil { + *peerAddr = n.getRemoteAddress() + } return n, n.waiterQueue, nil } @@ -2577,11 +2582,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotConnected } + return e.getRemoteAddress(), nil +} + +func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort, NIC: e.boundNICID, - }, nil + } } func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index adb32e428..3d09d6def 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -291,12 +291,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2203,12 +2203,12 @@ func TestScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2277,12 +2277,12 @@ func TestNonScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2840,12 +2840,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2895,12 +2895,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5135,12 +5135,12 @@ func TestListenBacklogFull(t *testing.T) { defer c.WQ.EventUnregister(&we) for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5152,7 +5152,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5164,12 +5164,12 @@ func TestListenBacklogFull(t *testing.T) { // Now a new handshake must succeed. executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5476,12 +5476,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5552,12 +5552,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5568,7 +5568,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5657,7 +5657,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err != nil && err != tcpip.ErrWouldBlock { t.Fatalf("Accept failed: %s", err) @@ -5672,7 +5672,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5730,12 +5730,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5800,12 +5800,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5853,12 +5853,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - aep, _, err = ep.Accept() + aep, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6293,12 +6293,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6412,12 +6412,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6519,12 +6519,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6602,12 +6602,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.SendPacket(nil, ackHeaders) // Try to accept the connection. - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6675,12 +6675,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6824,12 +6824,12 @@ func TestTCPCloseWithData(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -7271,8 +7271,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Send data. This should result in an acceptable endpoint. @@ -7293,9 +7293,9 @@ func TestTCPDeferAccept(t *testing.T) { // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7329,8 +7329,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7362,9 +7362,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 1f5340cd0..8bb5e5f6d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -948,12 +948,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { c.t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c74bc4d94..2828b2c01 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1218,7 +1218,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index ffcd90475..54fee2e82 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -1161,30 +1161,26 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { SyscallSucceeds()); ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); - // TODO(gvisor.dev/issue/3780): Remove this. if (IsRunningOnGvisor()) { - // Wait for the RST to be observed. + // Gvisor packet procssing is asynchronous and can take a bit of time in + // some cases so we give it a bit of time to process the RST packet before + // calling accept. + // + // There is nothing to poll() on so we have no choice but to use a sleep + // here. absl::SleepFor(absl::Milliseconds(100)); } sockaddr_storage accept_addr; socklen_t addrlen = sizeof(accept_addr); - // TODO(gvisor.dev/issue/3780): Remove this. - if (IsRunningOnGvisor()) { - ASSERT_THAT(accept(listen_fd.get(), - reinterpret_cast<sockaddr*>(&accept_addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - return; - } - - conn_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( + auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen)); ASSERT_EQ(addrlen, listener.addr_len); int err; socklen_t optlen = sizeof(err); - ASSERT_THAT(getsockopt(conn_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + ASSERT_THAT(getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), SyscallSucceeds()); ASSERT_EQ(err, ECONNRESET); ASSERT_EQ(optlen, sizeof(err)); |