summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go7
-rw-r--r--pkg/tcpip/tcpip.go4
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go9
-rw-r--r--test/syscalls/linux/tcp_socket.cc9
5 files changed, 25 insertions, 8 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 049d04bf2..ed2fbcceb 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2229,11 +2229,16 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
var copied int
// Copy as many views as possible into the user-provided buffer.
- for dst.NumBytes() != 0 {
+ for {
+ // Always do at least one fetchReadView, even if the number of bytes to
+ // read is 0.
err = s.fetchReadView()
if err != nil {
break
}
+ if dst.NumBytes() == 0 {
+ break
+ }
var n int
var e error
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0fa141d58..d29d9a704 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1124,6 +1124,10 @@ type ReadErrors struct {
// InvalidEndpointState is the number of times we found the endpoint state
// to be unexpected.
InvalidEndpointState StatCounter
+
+ // NotConnected is the number of times we tried to read but found that the
+ // endpoint was not connected.
+ NotConnected StatCounter
}
// WriteErrors collects packet write errors from an endpoint write call.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b5a8e15ee..e4a6b1b8b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1003,8 +1003,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ e.stats.ReadErrors.NotConnected.Increment()
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
}
v, err := e.readLocked()
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 2c1505067..cc118c993 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5405,12 +5405,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- // Expect InvalidEndpointState errors on a read at this point.
- if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
+ t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
}
- if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
- t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
}
if err := ep.Listen(10); err != nil {
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 525ccbd88..8a8b68e75 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -1339,6 +1339,15 @@ TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) {
EXPECT_EQ(get, kTCPDeferAccept);
}
+TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ char buf[1];
+ EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));