diff options
20 files changed, 163 insertions, 152 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 36c17d1ba..91790834b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -830,7 +830,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -839,7 +839,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Try to accept the connection again; if it fails, then wait until we // get a notification. for { - if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock { + if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock { return ep, wq, syserr.TranslateNetstackError(err) } @@ -852,15 +852,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -880,13 +883,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 1f7d17f5f..0f342e655 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -151,14 +151,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs // tcpip.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -176,13 +180,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { + if peerAddr != nil { // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index e3a75b519..aa4f3c04d 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { } // Accept accepts a new connection. -func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { +func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { e.Lock() defer e.Unlock() @@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { select { case ne := <-e.acceptedChan: + if peerAddr != nil { + ne.Lock() + c := ne.connected + ne.Unlock() + if c != nil { + addr, err := c.GetLocalAddress() + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + *peerAddr = addr + } + } return ne, nil default: diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 4751b2fd8..f8aacca13 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi } // Listen starts listening on the connection. -func (e *connectionlessEndpoint) Listen(int) *syserr.Error { +func (*connectionlessEndpoint) Listen(int) *syserr.Error { return syserr.ErrNotSupported } // Accept accepts a new connection. -func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) { +func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) { return nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 1200cf9bb..cbbdd000f 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -151,7 +151,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *syserr.Error) + // + // peerAddr if not nil will be populated with the address of the connected + // peer on a successful accept. + Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 0a7a26495..616530eb6 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -205,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -214,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -227,15 +227,18 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -252,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 65a285b8f..e25c7e84a 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -96,7 +96,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.socketOpsCommon.EventRegister(&e, waiter.EventIn) @@ -105,7 +105,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -118,15 +118,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -144,13 +147,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 68a954a10..4f551cd92 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { // Accept implements net.Conn.Accept. func (l *TCPListener) Accept() (net.Conn, error) { - n, wq, err := l.ep.Accept() + n, wq, err := l.ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. @@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { defer l.wq.EventUnregister(&waitEntry) for { - n, wq, err = l.ep.Accept() + n, wq, err = l.ep.Accept(nil) if err != tcpip.ErrWouldBlock { break diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 9e37cab18..3f58a15ea 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -188,7 +188,7 @@ func main() { defer wq.EventUnregister(&waitEntry) for { - n, wq, err := ep.Accept() + n, wq, err := ep.Accept(nil) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a1458c899..9292bfccb 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -180,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -631,7 +631,7 @@ func TestTransportForwarding(t *testing.T) { Data: req.ToVectorisedView(), })) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err != nil || aep == nil { t.Fatalf("Accept failed: %v, %v", aep, err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b113d8613..8ba615521 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -561,7 +561,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *waiter.Queue, *Error) + // + // If peerAddr is not nil then it is populated with the peer address of the + // returned endpoint. + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 346ca4bda..ad71ff3b6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -597,7 +597,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 81093e9ca..8bd4e5e37 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -192,13 +192,13 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes return ep.ReadPacket(addr, nil) } -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -210,25 +210,25 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes tcpip.ErrNotSupported. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrNotSupported } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { return tcpip.ErrNotSupported } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with // Accept, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -267,12 +267,12 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -371,7 +371,7 @@ func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { +func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrNotSupported } @@ -508,7 +508,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (*endpoint) State() uint32 { return 0 } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 71feeb748..fb03e6047 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -446,12 +446,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -482,12 +482,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 6074cc24e..80e9dd465 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -371,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -510,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() + var addr tcpip.FullAddress + nep, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -526,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) { } } + if addr.Addr != context.TestV6Addr { + t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) + } + // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) } } @@ -610,12 +604,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3f18efeef..4cf966b65 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2453,7 +2453,9 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +// +// addr if not-nil will contain the peer address of the returned endpoint. +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2475,6 +2477,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } + if peerAddr != nil { + *peerAddr = n.getRemoteAddress() + } return n, n.waiterQueue, nil } @@ -2577,11 +2582,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotConnected } + return e.getRemoteAddress(), nil +} + +func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort, NIC: e.boundNICID, - }, nil + } } func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index adb32e428..3d09d6def 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -291,12 +291,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2203,12 +2203,12 @@ func TestScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2277,12 +2277,12 @@ func TestNonScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2840,12 +2840,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2895,12 +2895,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5135,12 +5135,12 @@ func TestListenBacklogFull(t *testing.T) { defer c.WQ.EventUnregister(&we) for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5152,7 +5152,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5164,12 +5164,12 @@ func TestListenBacklogFull(t *testing.T) { // Now a new handshake must succeed. executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5476,12 +5476,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5552,12 +5552,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5568,7 +5568,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5657,7 +5657,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err != nil && err != tcpip.ErrWouldBlock { t.Fatalf("Accept failed: %s", err) @@ -5672,7 +5672,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5730,12 +5730,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5800,12 +5800,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5853,12 +5853,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - aep, _, err = ep.Accept() + aep, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6293,12 +6293,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6412,12 +6412,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6519,12 +6519,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6602,12 +6602,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.SendPacket(nil, ackHeaders) // Try to accept the connection. - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6675,12 +6675,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6824,12 +6824,12 @@ func TestTCPCloseWithData(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -7271,8 +7271,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Send data. This should result in an acceptable endpoint. @@ -7293,9 +7293,9 @@ func TestTCPDeferAccept(t *testing.T) { // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7329,8 +7329,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7362,9 +7362,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 1f5340cd0..8bb5e5f6d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -948,12 +948,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { c.t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c74bc4d94..2828b2c01 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1218,7 +1218,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index ffcd90475..54fee2e82 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -1161,30 +1161,26 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { SyscallSucceeds()); ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); - // TODO(gvisor.dev/issue/3780): Remove this. if (IsRunningOnGvisor()) { - // Wait for the RST to be observed. + // Gvisor packet procssing is asynchronous and can take a bit of time in + // some cases so we give it a bit of time to process the RST packet before + // calling accept. + // + // There is nothing to poll() on so we have no choice but to use a sleep + // here. absl::SleepFor(absl::Milliseconds(100)); } sockaddr_storage accept_addr; socklen_t addrlen = sizeof(accept_addr); - // TODO(gvisor.dev/issue/3780): Remove this. - if (IsRunningOnGvisor()) { - ASSERT_THAT(accept(listen_fd.get(), - reinterpret_cast<sockaddr*>(&accept_addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - return; - } - - conn_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( + auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen)); ASSERT_EQ(addrlen, listener.addr_len); int err; socklen_t optlen = sizeof(err); - ASSERT_THAT(getsockopt(conn_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + ASSERT_THAT(getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), SyscallSucceeds()); ASSERT_EQ(err, ECONNRESET); ASSERT_EQ(optlen, sizeof(err)); |