summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go22
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go16
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go14
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go4
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go5
-rw-r--r--pkg/sentry/socket/unix/unix.go22
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go22
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go4
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go2
-rw-r--r--pkg/tcpip/stack/transport_test.go4
-rw-r--r--pkg/tcpip/tcpip.go5
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go20
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go30
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go13
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go96
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc20
20 files changed, 163 insertions, 152 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 36c17d1ba..91790834b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -830,7 +830,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.EventRegister(&e, waiter.EventIn)
@@ -839,7 +839,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite
// Try to accept the connection again; if it fails, then wait until we
// get a notification.
for {
- if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock {
+ if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock {
return ep, wq, syserr.TranslateNetstackError(err)
}
@@ -852,15 +852,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, wq, terr := s.Endpoint.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
if terr != tcpip.ErrWouldBlock || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
var err *syserr.Error
- ep, wq, err = s.blockingAccept(t)
+ ep, wq, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -880,13 +883,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer and write it to peer slice.
- var err *syserr.Error
- addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 1f7d17f5f..0f342e655 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -151,14 +151,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
// tcpip.Endpoint.
func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
- ep, wq, terr := s.Endpoint.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
if terr != tcpip.ErrWouldBlock || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
var err *syserr.Error
- ep, wq, err = s.blockingAccept(t)
+ ep, wq, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -176,13 +180,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
+ if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- var err *syserr.Error
- addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ addr, addrLen = ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index e3a75b519..aa4f3c04d 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
}
// Accept accepts a new connection.
-func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
+func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) {
e.Lock()
defer e.Unlock()
@@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
select {
case ne := <-e.acceptedChan:
+ if peerAddr != nil {
+ ne.Lock()
+ c := ne.connected
+ ne.Unlock()
+ if c != nil {
+ addr, err := c.GetLocalAddress()
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ *peerAddr = addr
+ }
+ }
return ne, nil
default:
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 4751b2fd8..f8aacca13 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi
}
// Listen starts listening on the connection.
-func (e *connectionlessEndpoint) Listen(int) *syserr.Error {
+func (*connectionlessEndpoint) Listen(int) *syserr.Error {
return syserr.ErrNotSupported
}
// Accept accepts a new connection.
-func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) {
+func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) {
return nil, syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 1200cf9bb..cbbdd000f 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -151,7 +151,10 @@ type Endpoint interface {
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, *syserr.Error)
+ //
+ // peerAddr if not nil will be populated with the address of the connected
+ // peer on a successful accept.
+ Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 0a7a26495..616530eb6 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -205,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.EventRegister(&e, waiter.EventIn)
@@ -214,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Try to accept the connection; if it fails, then wait until we get a
// notification.
for {
- if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock {
return ep, err
}
@@ -227,15 +227,18 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, err := s.ep.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, err := s.ep.Accept(peerAddr)
if err != nil {
if err != syserr.ErrWouldBlock || !blocking {
return 0, nil, 0, err
}
var err *syserr.Error
- ep, err = s.blockingAccept(t)
+ ep, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -252,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer.
- var err *syserr.Error
- addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 65a285b8f..e25c7e84a 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -96,7 +96,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
@@ -105,7 +105,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr
// Try to accept the connection; if it fails, then wait until we get a
// notification.
for {
- if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock {
return ep, err
}
@@ -118,15 +118,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, err := s.ep.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, err := s.ep.Accept(peerAddr)
if err != nil {
if err != syserr.ErrWouldBlock || !blocking {
return 0, nil, 0, err
}
var err *syserr.Error
- ep, err = s.blockingAccept(t)
+ ep, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -144,13 +147,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer.
- var err *syserr.Error
- addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 68a954a10..4f551cd92 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
// Accept implements net.Conn.Accept.
func (l *TCPListener) Accept() (net.Conn, error) {
- n, wq, err := l.ep.Accept()
+ n, wq, err := l.ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) {
defer l.wq.EventUnregister(&waitEntry)
for {
- n, wq, err = l.ep.Accept()
+ n, wq, err = l.ep.Accept(nil)
if err != tcpip.ErrWouldBlock {
break
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 9e37cab18..3f58a15ea 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -188,7 +188,7 @@ func main() {
defer wq.EventUnregister(&waitEntry)
for {
- n, wq, err := ep.Accept()
+ n, wq, err := ep.Accept(nil)
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a1458c899..9292bfccb 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -180,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
return nil
}
-func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
@@ -631,7 +631,7 @@ func TestTransportForwarding(t *testing.T) {
Data: req.ToVectorisedView(),
}))
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err != nil || aep == nil {
t.Fatalf("Accept failed: %v, %v", aep, err)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index b113d8613..8ba615521 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -561,7 +561,10 @@ type Endpoint interface {
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, *waiter.Queue, *Error)
+ //
+ // If peerAddr is not nil then it is populated with the peer address of the
+ // returned endpoint.
+ Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 346ca4bda..ad71ff3b6 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -597,7 +597,7 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 81093e9ca..8bd4e5e37 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -192,13 +192,13 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
return ep.ReadPacket(addr, nil)
}
-func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -210,25 +210,25 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
// connected, and this function always returnes tcpip.ErrNotSupported.
-func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
// with Shutdown, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
// Listen, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
// Accept, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -267,12 +267,12 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -371,7 +371,7 @@ func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return false, tcpip.ErrNotSupported
}
@@ -508,7 +508,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
// State implements socket.Socket.State.
-func (ep *endpoint) State() uint32 {
+func (*endpoint) State() uint32 {
return 0
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 71feeb748..fb03e6047 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -446,12 +446,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen implements tcpip.Endpoint.Listen.
-func (e *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -482,12 +482,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 6074cc24e..80e9dd465 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -371,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -510,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept()
+ var addr tcpip.FullAddress
+ nep, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -526,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) {
}
}
+ if addr.Addr != context.TestV6Addr {
+ t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
+ }
+
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Fatalf("GetSockOpt failed failed: %v", err)
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestV6Addr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
+ t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
}
}
@@ -610,12 +604,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 3f18efeef..4cf966b65 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2453,7 +2453,9 @@ func (e *endpoint) startAcceptedLoop() {
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+//
+// addr if not-nil will contain the peer address of the returned endpoint.
+func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2475,6 +2477,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
+ if peerAddr != nil {
+ *peerAddr = n.getRemoteAddress()
+ }
return n, n.waiterQueue, nil
}
@@ -2577,11 +2582,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
+ return e.getRemoteAddress(), nil
+}
+
+func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
return tcpip.FullAddress{
Addr: e.ID.RemoteAddress,
Port: e.ID.RemotePort,
NIC: e.boundNICID,
- }, nil
+ }
}
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index adb32e428..3d09d6def 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -291,12 +291,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2203,12 +2203,12 @@ func TestScaledWindowAccept(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2277,12 +2277,12 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2840,12 +2840,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2895,12 +2895,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5135,12 +5135,12 @@ func TestListenBacklogFull(t *testing.T) {
defer c.WQ.EventUnregister(&we)
for i := 0; i < listenBacklog; i++ {
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5152,7 +5152,7 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -5164,12 +5164,12 @@ func TestListenBacklogFull(t *testing.T) {
// Now a new handshake must succeed.
executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5476,12 +5476,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5552,12 +5552,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5568,7 +5568,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -5657,7 +5657,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err != nil && err != tcpip.ErrWouldBlock {
t.Fatalf("Accept failed: %s", err)
@@ -5672,7 +5672,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5730,12 +5730,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Verify that there is only one acceptable connection at this point.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5800,12 +5800,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Now check that there is one acceptable connections.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5853,12 +5853,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- aep, _, err = ep.Accept()
+ aep, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6293,12 +6293,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6412,12 +6412,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6519,12 +6519,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6602,12 +6602,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.SendPacket(nil, ackHeaders)
// Try to accept the connection.
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6675,12 +6675,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6824,12 +6824,12 @@ func TestTCPCloseWithData(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -7271,8 +7271,8 @@ func TestTCPDeferAccept(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Send data. This should result in an acceptable endpoint.
@@ -7293,9 +7293,9 @@ func TestTCPDeferAccept(t *testing.T) {
// Give a bit of time for the socket to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
@@ -7329,8 +7329,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7362,9 +7362,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
// Give sometime for the endpoint to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 1f5340cd0..8bb5e5f6d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -948,12 +948,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
c.t.Fatalf("Accept failed: %v", err)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index c74bc4d94..2828b2c01 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1218,7 +1218,7 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index ffcd90475..54fee2e82 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -1161,30 +1161,26 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
SyscallSucceeds());
ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds());
- // TODO(gvisor.dev/issue/3780): Remove this.
if (IsRunningOnGvisor()) {
- // Wait for the RST to be observed.
+ // Gvisor packet procssing is asynchronous and can take a bit of time in
+ // some cases so we give it a bit of time to process the RST packet before
+ // calling accept.
+ //
+ // There is nothing to poll() on so we have no choice but to use a sleep
+ // here.
absl::SleepFor(absl::Milliseconds(100));
}
sockaddr_storage accept_addr;
socklen_t addrlen = sizeof(accept_addr);
- // TODO(gvisor.dev/issue/3780): Remove this.
- if (IsRunningOnGvisor()) {
- ASSERT_THAT(accept(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&accept_addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
- return;
- }
-
- conn_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept(
+ auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept(
listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen));
ASSERT_EQ(addrlen, listener.addr_len);
int err;
socklen_t optlen = sizeof(err);
- ASSERT_THAT(getsockopt(conn_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
+ ASSERT_THAT(getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
SyscallSucceeds());
ASSERT_EQ(err, ECONNRESET);
ASSERT_EQ(optlen, sizeof(err));