diff options
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 17 | ||||
-rw-r--r-- | test/syscalls/linux/socket_inet_loopback.cc | 75 |
2 files changed, 48 insertions, 44 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e188efccb..80ad1a9d4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { return eps } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { +// handlePacket is called by the stack when new packets arrive to this transport +// endpoint. It returns false if the packet could not be matched to any +// transport endpoint, true otherwise. +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return false } } @@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return true } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() - return + return true } transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. + return true } // handleError delivers an error to the transport endpoint identified by id. @@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, } return false } - ep.handlePacket(id, pkt) - return true + return ep.handlePacket(id, pkt) } // deliverRawPacket attempts to deliver the given packet and returns whether it diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 38e6f8b5c..9a6b089f6 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -477,47 +477,50 @@ void TestHangupDuringConnect(const TestParam& param, TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; - // Create the listening socket. - FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), 1), SyscallSucceeds()); + for (int i = 0; i < 100; i++) { + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), 0), SyscallSucceeds()); - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - // Connect asynchronously and immediately hang up the listener. - FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } + // Connect asynchronously and immediately hang up the listener. + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } - hangup(listen_fd); + hangup(listen_fd); - // Wait for the connection to close. - struct pollfd pfd = { - .fd = client.get(), - }; - constexpr int kTimeout = 10000; - int n = poll(&pfd, 1, kTimeout); - ASSERT_GE(n, 0) << strerror(errno); - ASSERT_EQ(n, 1); - ASSERT_EQ(pfd.revents, POLLHUP | POLLERR); - ASSERT_EQ(close(client.release()), 0) << strerror(errno); + // Wait for the connection to close. + struct pollfd pfd = { + .fd = client.get(), + }; + constexpr int kTimeout = 10000; + int n = poll(&pfd, 1, kTimeout); + ASSERT_GE(n, 0) << strerror(errno); + ASSERT_EQ(n, 1); + ASSERT_EQ(pfd.revents, POLLHUP | POLLERR); + ASSERT_EQ(close(client.release()), 0) << strerror(errno); + } } TEST_P(SocketInetLoopbackTest, TCPListenCloseDuringConnect) { |