summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/connect.go11
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go147
2 files changed, 102 insertions, 56 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 1798510bc..6e5e55b6f 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1024,14 +1024,19 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.EndpointState() == StateClose {
+ s := e.EndpointState()
+ if s == StateClose {
return
}
+
+ if s.connected() {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+ }
+
// Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
e.setEndpointState(StateClose)
- e.stack.Stats().TCP.CurrentConnected.Decrement()
- e.stack.Stats().TCP.EstablishedClosed.Increment()
}
// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index e67ec42b1..fb25b86b9 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -146,6 +146,24 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
}
}
+func TestCloseWithoutConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ c.EP.Close()
+
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
func TestTCPSegmentsSentIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -1276,68 +1294,91 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
-func TestRstOnSynSent(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
+func TestSynSent(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reset bool
+ }{
+ {"RstOnSynSent", true},
+ {"CloseOnSynSent", false},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
- }
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
- // Ensure that we've reached SynSent state
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to Error and close out
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
+ if test.reset {
+ // Send a packet with a proper ACK and a RST flag to cause the socket
+ // to error and close out.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+ } else {
+ c.EP.Close()
+ }
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ if test.reset {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+ } else {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+ }
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ })
}
}