diff options
author | Andrei Vagin <avagin@google.com> | 2019-07-03 13:57:24 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-07-03 14:19:02 -0700 |
commit | 116cac053e2e4e167caa9707439065af7c7b82b3 (patch) | |
tree | 1cc098420d1f323f1e8d92dfba3ed4a4c7991cfe | |
parent | f10862696c8508ee69f25838e25caacabf55ef83 (diff) |
netstack/udp: connect with the AF_UNSPEC address family means disconnect
PiperOrigin-RevId: 256433283
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 17 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 2 | ||||
-rw-r--r-- | pkg/sentry/strace/socket.go | 2 | ||||
-rw-r--r-- | pkg/syserr/netstack.go | 75 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 79 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 2 | ||||
-rw-r--r-- | test/syscalls/linux/udp_socket.cc | 155 |
12 files changed, 299 insertions, 91 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index b2b2d98a1..9d1bcfd41 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -285,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // GetAddress reads an sockaddr struct from the given address and converts it // to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 // addresses. -func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { +func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) { + if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -317,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -330,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -343,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { } return out, nil + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, nil + default: return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -465,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, false /* strict */) if err != nil { return err } @@ -498,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -1922,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to) + addrBuf, err := GetAddress(s.family, to, true /* strict */) if err != nil { return 0, err } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index b30871a90..637168714 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -110,7 +110,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr) + addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f9cf2eb21..386b40af7 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, err := epsocket.GetAddress(int(family), b) + fa, err := epsocket.GetAddress(int(family), b, true /* strict */) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 965808a27..8ff922c69 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -49,43 +49,44 @@ var ( ) var netstackErrorTranslations = map[*tcpip.Error]*Error{ - tcpip.ErrUnknownProtocol: ErrUnknownProtocol, - tcpip.ErrUnknownNICID: ErrUnknownNICID, - tcpip.ErrUnknownDevice: ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID: ErrDuplicateNICID, - tcpip.ErrDuplicateAddress: ErrDuplicateAddress, - tcpip.ErrNoRoute: ErrNoRoute, - tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound: ErrAlreadyBound, - tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected: ErrAlreadyConnected, - tcpip.ErrNoPortAvailable: ErrNoPortAvailable, - tcpip.ErrPortInUse: ErrPortInUse, - tcpip.ErrBadLocalAddress: ErrBadLocalAddress, - tcpip.ErrClosedForSend: ErrClosedForSend, - tcpip.ErrClosedForReceive: ErrClosedForReceive, - tcpip.ErrWouldBlock: ErrWouldBlock, - tcpip.ErrConnectionRefused: ErrConnectionRefused, - tcpip.ErrTimeout: ErrTimeout, - tcpip.ErrAborted: ErrAborted, - tcpip.ErrConnectStarted: ErrConnectStarted, - tcpip.ErrDestinationRequired: ErrDestinationRequired, - tcpip.ErrNotSupported: ErrNotSupported, - tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported, - tcpip.ErrNotConnected: ErrNotConnected, - tcpip.ErrConnectionReset: ErrConnectionReset, - tcpip.ErrConnectionAborted: ErrConnectionAborted, - tcpip.ErrNoSuchFile: ErrNoSuchFile, - tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress: ErrHostDown, - tcpip.ErrBadAddress: ErrBadAddress, - tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, - tcpip.ErrMessageTooLong: ErrMessageTooLong, - tcpip.ErrNoBufferSpace: ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, - tcpip.ErrNotPermitted: ErrNotPermittedNet, + tcpip.ErrUnknownProtocol: ErrUnknownProtocol, + tcpip.ErrUnknownNICID: ErrUnknownNICID, + tcpip.ErrUnknownDevice: ErrUnknownDevice, + tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption, + tcpip.ErrDuplicateNICID: ErrDuplicateNICID, + tcpip.ErrDuplicateAddress: ErrDuplicateAddress, + tcpip.ErrNoRoute: ErrNoRoute, + tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint, + tcpip.ErrAlreadyBound: ErrAlreadyBound, + tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState, + tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting, + tcpip.ErrAlreadyConnected: ErrAlreadyConnected, + tcpip.ErrNoPortAvailable: ErrNoPortAvailable, + tcpip.ErrPortInUse: ErrPortInUse, + tcpip.ErrBadLocalAddress: ErrBadLocalAddress, + tcpip.ErrClosedForSend: ErrClosedForSend, + tcpip.ErrClosedForReceive: ErrClosedForReceive, + tcpip.ErrWouldBlock: ErrWouldBlock, + tcpip.ErrConnectionRefused: ErrConnectionRefused, + tcpip.ErrTimeout: ErrTimeout, + tcpip.ErrAborted: ErrAborted, + tcpip.ErrConnectStarted: ErrConnectStarted, + tcpip.ErrDestinationRequired: ErrDestinationRequired, + tcpip.ErrNotSupported: ErrNotSupported, + tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported, + tcpip.ErrNotConnected: ErrNotConnected, + tcpip.ErrConnectionReset: ErrConnectionReset, + tcpip.ErrConnectionAborted: ErrConnectionAborted, + tcpip.ErrNoSuchFile: ErrNoSuchFile, + tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue, + tcpip.ErrNoLinkAddress: ErrHostDown, + tcpip.ErrBadAddress: ErrBadAddress, + tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, + tcpip.ErrMessageTooLong: ErrMessageTooLong, + tcpip.ErrNoBufferSpace: ErrNoBufferSpace, + tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, + tcpip.ErrNotPermitted: ErrNotPermittedNet, + tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported, } // TranslateNetstackError converts an error from the tcpip package to a sentry diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 966b3acfd..c61f96fb0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -66,43 +66,44 @@ func (e *Error) IgnoreStats() bool { // Errors that can be returned by the network stack. var ( - ErrUnknownProtocol = &Error{msg: "unknown protocol"} - ErrUnknownNICID = &Error{msg: "unknown nic id"} - ErrUnknownDevice = &Error{msg: "unknown device"} - ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} - ErrDuplicateNICID = &Error{msg: "duplicate nic id"} - ErrDuplicateAddress = &Error{msg: "duplicate address"} - ErrNoRoute = &Error{msg: "no route"} - ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} - ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} - ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} - ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} - ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} - ErrNoPortAvailable = &Error{msg: "no ports are available"} - ErrPortInUse = &Error{msg: "port is in use"} - ErrBadLocalAddress = &Error{msg: "bad local address"} - ErrClosedForSend = &Error{msg: "endpoint is closed for send"} - ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} - ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} - ErrConnectionRefused = &Error{msg: "connection was refused"} - ErrTimeout = &Error{msg: "operation timed out"} - ErrAborted = &Error{msg: "operation aborted"} - ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} - ErrDestinationRequired = &Error{msg: "destination address is required"} - ErrNotSupported = &Error{msg: "operation not supported"} - ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} - ErrNotConnected = &Error{msg: "endpoint not connected"} - ErrConnectionReset = &Error{msg: "connection reset by peer"} - ErrConnectionAborted = &Error{msg: "connection aborted"} - ErrNoSuchFile = &Error{msg: "no such file"} - ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} - ErrNoLinkAddress = &Error{msg: "no remote link address"} - ErrBadAddress = &Error{msg: "bad address"} - ErrNetworkUnreachable = &Error{msg: "network is unreachable"} - ErrMessageTooLong = &Error{msg: "message too long"} - ErrNoBufferSpace = &Error{msg: "no buffer space available"} - ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} - ErrNotPermitted = &Error{msg: "operation not permitted"} + ErrUnknownProtocol = &Error{msg: "unknown protocol"} + ErrUnknownNICID = &Error{msg: "unknown nic id"} + ErrUnknownDevice = &Error{msg: "unknown device"} + ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} + ErrDuplicateNICID = &Error{msg: "duplicate nic id"} + ErrDuplicateAddress = &Error{msg: "duplicate address"} + ErrNoRoute = &Error{msg: "no route"} + ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} + ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} + ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} + ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} + ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} + ErrNoPortAvailable = &Error{msg: "no ports are available"} + ErrPortInUse = &Error{msg: "port is in use"} + ErrBadLocalAddress = &Error{msg: "bad local address"} + ErrClosedForSend = &Error{msg: "endpoint is closed for send"} + ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} + ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} + ErrConnectionRefused = &Error{msg: "connection was refused"} + ErrTimeout = &Error{msg: "operation timed out"} + ErrAborted = &Error{msg: "operation aborted"} + ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} + ErrDestinationRequired = &Error{msg: "destination address is required"} + ErrNotSupported = &Error{msg: "operation not supported"} + ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} + ErrNotConnected = &Error{msg: "endpoint not connected"} + ErrConnectionReset = &Error{msg: "connection reset by peer"} + ErrConnectionAborted = &Error{msg: "connection aborted"} + ErrNoSuchFile = &Error{msg: "no such file"} + ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} + ErrNoLinkAddress = &Error{msg: "no remote link address"} + ErrBadAddress = &Error{msg: "bad address"} + ErrNetworkUnreachable = &Error{msg: "network is unreachable"} + ErrMessageTooLong = &Error{msg: "message too long"} + ErrNoBufferSpace = &Error{msg: "no buffer space available"} + ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} + ErrNotPermitted = &Error{msg: "operation not permitted"} + ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ) // Errors related to Subnet @@ -339,6 +340,10 @@ type Endpoint interface { // get the actual result. The first call to Connect after the socket has // connected returns nil. Calling connect again results in ErrAlreadyConnected. // Anything else -- the attempt to connect failed. + // + // If address.Addr is empty, this means that Enpoint has to be + // disconnected if this is supported, otherwise + // ErrAddressFamilyNotSupported must be returned. Connect(address FullAddress) *Error // Shutdown closes the read and/or write end of the endpoint connection diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 33cefd937..ab9e80747 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -422,6 +422,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() + if addr.Addr == "" { + // AF_UNSPEC isn't supported. + return tcpip.ErrAddressFamilyNotSupported + } + nicid := addr.NIC localPort := uint16(0) switch e.state { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 73599dd9d..42aded77f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -298,6 +298,11 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() + if addr.Addr == "" { + // AF_UNSPEC isn't supported. + return tcpip.ErrAddressFamilyNotSupported + } + if ep.closed { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ee60ebf58..cb40fea94 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1271,6 +1271,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Addr == "" && addr.Port == 0 { + // AF_UNSPEC isn't supported. + return tcpip.ErrAddressFamilyNotSupported + } + return e.connect(addr, true, true) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ec61a3886..b93959034 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -342,6 +342,7 @@ func loadError(s string) *tcpip.Error { tcpip.ErrNoBufferSpace, tcpip.ErrBroadcastDisabled, tcpip.ErrNotPermitted, + tcpip.ErrAddressFamilyNotSupported, } messageToError = make(map[string]*tcpip.Error) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 99fdfb795..cb0ea83a6 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -698,8 +698,44 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } +func (e *endpoint) disconnect() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.state != stateConnected { + return nil + } + id := stack.TransportEndpointID{} + // Exclude ephemerally bound endpoints. + if e.bindNICID != 0 || e.id.LocalAddress == "" { + var err *tcpip.Error + id = stack.TransportEndpointID{ + LocalPort: e.id.LocalPort, + LocalAddress: e.id.LocalAddress, + } + id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + if err != nil { + return err + } + e.state = stateBound + } else { + e.state = stateInitial + } + + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.id = id + e.route.Release() + e.route = stack.Route{} + e.dstPort = 0 + + return nil +} + // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Addr == "" { + return e.disconnect() + } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -734,12 +770,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress, + LocalAddress: e.id.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, RemoteAddress: r.RemoteAddress, } + if e.state == stateInitial { + id.LocalAddress = r.LocalAddress + } + // Even if we're connected, this endpoint can still be used to send // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 701bdd72b..18e786397 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -92,8 +92,6 @@ func (e *endpoint) afterLoad() { if err != nil { panic(*err) } - - e.id.LocalAddress = e.route.LocalAddress } else if len(e.id.LocalAddress) != 0 { // stateBound if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 31db8a2ad..1bb0307c4 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <arpa/inet.h> #include <fcntl.h> #include <linux/errqueue.h> #include <netinet/in.h> @@ -304,12 +305,50 @@ TEST_P(UdpSocketTest, ReceiveAfterConnect) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); } +TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { + // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); + + // Get the address s_ was bound to during connect. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + for (int i = 0; i < 2; i++) { + // Send from t_ to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Disconnect s_. + struct sockaddr addr = {}; + addr.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds()); + // Connect s_ loopback:TestPort. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + } +} + TEST_P(UdpSocketTest, Connect) { ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); @@ -335,6 +374,112 @@ TEST_P(UdpSocketTest, Connect) { EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); } +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { + struct sockaddr_storage baddr = {}; + socklen_t addrlen; + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + if (addr_[0]->sa_family == AF_INET) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addr_in->sin_family = AF_INET; + addr_in->sin_port = port; + inet_pton(AF_INET, "0.0.0.0", + reinterpret_cast<void*>(&addr_in->sin_addr.s_addr)); + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + inet_pton(AF_INET6, + "::", reinterpret_cast<void*>(&addr_in->sin6_addr.s6_addr)); + addr_in->sin6_scope_id = 0; + } + ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), + SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, Disconnect) { + for (int i = 0; i < 2; i++) { + // Try to connect again. + EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); + + // Try to disconnect. + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&addr), 0); + } +} + +TEST_P(UdpSocketTest, ConnectBadAddress) { + struct sockaddr addr = {}; + addr.sa_family = addr_[0]->sa_family; + ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), + SyscallFailsWithErrno(EINVAL)); +} + TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); @@ -397,7 +542,7 @@ TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -419,7 +564,7 @@ TEST_P(UdpSocketTest, SendAndReceiveConnected) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -462,7 +607,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) { ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); // Receive the data. It works because it was sent before the connect. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -491,7 +636,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data and sender address. - char received[512]; + char received[sizeof(buf)]; struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0, |