summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go45
-rw-r--r--pkg/sentry/socket/unix/unix.go2
-rw-r--r--pkg/sentry/strace/socket.go2
-rw-r--r--pkg/tcpip/stack/transport_test.go5
-rw-r--r--pkg/tcpip/tcpip.go3
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
-rw-r--r--test/syscalls/linux/udp_socket.cc124
10 files changed, 160 insertions, 57 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 4d8a5ac22..635042263 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -291,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// GetAddress reads an sockaddr struct from the given address and converts it
-// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
-// addresses.
-func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
+// AF_INET6 addresses.
+//
+// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
+//
+// AddressAndFamily returns an address, its family.
+func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
}
// Get the rest of the fields based on the address family.
@@ -310,7 +314,7 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
// Drop the terminating NUL (if one exists) and everything after
// it for filesystem (non-abstract) addresses.
@@ -321,12 +325,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
}
return tcpip.FullAddress{
Addr: tcpip.Address(path),
- }, nil
+ }, family, nil
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -334,12 +338,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
Addr: bytesToIPAddress(a.Addr[:]),
Port: ntohs(a.Port),
}
- return out, nil
+ return out, family, nil
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -350,13 +354,13 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
if isLinkLocal(out.Addr) {
out.NIC = tcpip.NICID(a.Scope_id)
}
- return out, nil
+ return out, family, nil
case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, nil
+ return tcpip.FullAddress{}, family, nil
default:
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
}
}
@@ -482,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, false /* strict */)
+ addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
+ if family == linux.AF_UNSPEC {
+ err := s.Endpoint.Disconnect()
+ if err == tcpip.ErrNotSupported {
+ return syserr.ErrAddressFamilyNotSupported
+ }
+ return syserr.TranslateNetstackError(err)
+ }
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -515,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, true /* strict */)
+ addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -2023,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to, true /* strict */)
+ addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 8a3f65236..0d0cb68df 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 386b40af7..f779186ad 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, err := epsocket.GetAddress(int(family), b, true /* strict */)
+ fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8146e8444..1c811ab68 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -105,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index ba2dd85b8..29a6025d9 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -358,6 +358,9 @@ type Endpoint interface {
// ErrAddressFamilyNotSupported must be returned.
Connect(address FullAddress) *Error
+ // Disconnect disconnects the endpoint from its peer.
+ Disconnect() *Error
+
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
Shutdown(flags ShutdownFlags) *Error
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 2e8d5d4bf..451d3880e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -428,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 53c9515a4..13e17e2a6 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -349,16 +349,16 @@ func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error)
return 0, tcpip.ControlMessages{}, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect implements tcpip.Endpoint.Connect.
func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
if ep.closed {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e5f835c20..24b32e4af 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1362,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" && addr.Port == 0 {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
return e.connect(addr, true, true)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 640bb8667..8b3356406 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -711,7 +711,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
-func (e *endpoint) disconnect() *tcpip.Error {
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -750,9 +751,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- if addr.Addr == "" {
- return e.disconnect()
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 6ffb65168..111dbacdf 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -378,16 +378,17 @@ TEST_P(UdpSocketTest, Connect) {
EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
}
-TEST_P(UdpSocketTest, ConnectAny) {
+void ConnectAny(AddressFamily family, int sockfd, uint16_t port) {
struct sockaddr_storage addr = {};
// Precondition check.
{
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(
+ getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
- if (GetParam() == AddressFamily::kIpv4) {
+ if (family == AddressFamily::kIpv4) {
auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
EXPECT_EQ(addrlen, sizeof(*addr_out));
EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
@@ -400,21 +401,24 @@ TEST_P(UdpSocketTest, ConnectAny) {
{
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(
+ getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
}
struct sockaddr_storage baddr = {};
- if (GetParam() == AddressFamily::kIpv4) {
+ if (family == AddressFamily::kIpv4) {
auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
addrlen = sizeof(*addr_in);
addr_in->sin_family = AF_INET;
addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
+ addr_in->sin_port = port;
} else {
auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
addrlen = sizeof(*addr_in);
addr_in->sin6_family = AF_INET6;
- if (GetParam() == AddressFamily::kIpv6) {
+ addr_in->sin6_port = port;
+ if (family == AddressFamily::kIpv6) {
addr_in->sin6_addr = IN6ADDR_ANY_INIT;
} else {
TestAddress const& v4_mapped_any = V4MappedAny();
@@ -424,21 +428,23 @@ TEST_P(UdpSocketTest, ConnectAny) {
}
}
- ASSERT_THAT(connect(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen),
+ // TODO(b/138658473): gVisor doesn't allow connecting to the zero port.
+ if (port == 0) {
+ SKIP_IF(IsRunningOnGvisor());
+ }
+
+ ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen),
SyscallSucceeds());
}
- // TODO(b/138658473): gVisor doesn't return the correct local address after
- // connecting to the any address.
- SKIP_IF(IsRunningOnGvisor());
-
// Postcondition check.
{
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(
+ getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
- if (GetParam() == AddressFamily::kIpv4) {
+ if (family == AddressFamily::kIpv4) {
auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
EXPECT_EQ(addrlen, sizeof(*addr_out));
EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK));
@@ -446,7 +452,7 @@ TEST_P(UdpSocketTest, ConnectAny) {
auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
EXPECT_EQ(addrlen, sizeof(*addr_out));
struct in6_addr loopback;
- if (GetParam() == AddressFamily::kIpv6) {
+ if (family == AddressFamily::kIpv6) {
loopback = IN6ADDR_LOOPBACK_INIT;
} else {
TestAddress const& v4_mapped_loopback = V4MappedLoopback();
@@ -459,11 +465,91 @@ TEST_P(UdpSocketTest, ConnectAny) {
}
addrlen = sizeof(addr);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
+ if (port == 0) {
+ EXPECT_THAT(
+ getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+ } else {
+ EXPECT_THAT(
+ getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ }
+ }
+}
+
+TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); }
+
+TEST_P(UdpSocketTest, ConnectAnyWithPort) {
+ auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
+ ConnectAny(GetParam(), s_, port);
+}
+
+void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) {
+ struct sockaddr_storage addr = {};
+
+ socklen_t addrlen = sizeof(addr);
+ struct sockaddr_storage baddr = {};
+ if (family == AddressFamily::kIpv4) {
+ auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
+ addrlen = sizeof(*addr_in);
+ addr_in->sin_family = AF_INET;
+ addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
+ addr_in->sin_port = port;
+ } else {
+ auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
+ addrlen = sizeof(*addr_in);
+ addr_in->sin6_family = AF_INET6;
+ addr_in->sin6_port = port;
+ if (family == AddressFamily::kIpv6) {
+ addr_in->sin6_addr = IN6ADDR_ANY_INIT;
+ } else {
+ TestAddress const& v4_mapped_any = V4MappedAny();
+ addr_in->sin6_addr =
+ reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
+ ->sin6_addr;
+ }
+ }
+
+ // TODO(b/138658473): gVisor doesn't allow connecting to the zero port.
+ if (port == 0) {
+ SKIP_IF(IsRunningOnGvisor());
+ }
+
+ ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen),
+ SyscallSucceeds());
+ // Now the socket is bound to the loopback address.
+
+ // Disconnect
+ addrlen = sizeof(addr);
+ addr.ss_family = AF_UNSPEC;
+ ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Check that after disconnect the socket is bound to the ANY address.
+ EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ if (family == AddressFamily::kIpv4) {
+ auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
+ } else {
+ auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ struct in6_addr loopback = IN6ADDR_ANY_INIT;
+
+ EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
}
}
+TEST_P(UdpSocketTest, DisconnectAfterConnectAny) {
+ DisconnectAfterConnectAny(GetParam(), s_, 0);
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) {
+ auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
+ DisconnectAfterConnectAny(GetParam(), s_, port);
+}
+
TEST_P(UdpSocketTest, DisconnectAfterBind) {
ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
// Connect the socket.