summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcpconntrack
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/transport/tcpconntrack
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/transport/tcpconntrack')
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD24
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go333
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go501
3 files changed, 858 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
new file mode 100644
index 000000000..cf83ca134
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -0,0 +1,24 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "tcpconntrack",
+ srcs = ["tcp_conntrack.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ ],
+)
+
+go_test(
+ name = "tcpconntrack_test",
+ size = "small",
+ srcs = ["tcp_conntrack_test.go"],
+ deps = [
+ ":tcpconntrack",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
new file mode 100644
index 000000000..487f2572d
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -0,0 +1,333 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tcpconntrack implements a TCP connection tracking object. It allows
+// users with access to a segment stream to figure out when a connection is
+// established, reset, and closed (and in the last case, who closed first).
+package tcpconntrack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+// Result is returned when the state of a TCB is updated in response to an
+// inbound or outbound segment.
+type Result int
+
+const (
+ // ResultDrop indicates that the segment should be dropped.
+ ResultDrop Result = iota
+
+ // ResultConnecting indicates that the connection remains in a
+ // connecting state.
+ ResultConnecting
+
+ // ResultAlive indicates that the connection remains alive (connected).
+ ResultAlive
+
+ // ResultReset indicates that the connection was reset.
+ ResultReset
+
+ // ResultClosedByPeer indicates that the connection was gracefully
+ // closed, and the inbound stream was closed first.
+ ResultClosedByPeer
+
+ // ResultClosedBySelf indicates that the connection was gracefully
+ // closed, and the outbound stream was closed first.
+ ResultClosedBySelf
+)
+
+// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP
+// connection and inform the caller when the connection has been closed.
+type TCB struct {
+ inbound stream
+ outbound stream
+
+ // State handlers.
+ handlerInbound func(*TCB, header.TCP) Result
+ handlerOutbound func(*TCB, header.TCP) Result
+
+ // firstFin holds a pointer to the first stream to send a FIN.
+ firstFin *stream
+
+ // state is the current state of the stream.
+ state Result
+}
+
+// Init initializes the state of the TCB according to the initial SYN.
+func (t *TCB) Init(initialSyn header.TCP) {
+ t.handlerInbound = synSentStateInbound
+ t.handlerOutbound = synSentStateOutbound
+
+ iss := seqnum.Value(initialSyn.SequenceNumber())
+ t.outbound.una = iss
+ t.outbound.nxt = iss.Add(logicalLen(initialSyn))
+ t.outbound.end = t.outbound.nxt
+
+ // Even though "end" is a sequence number, we don't know the initial
+ // receive sequence number yet, so we store the window size until we get
+ // a SYN from the peer.
+ t.inbound.una = 0
+ t.inbound.nxt = 0
+ t.inbound.end = seqnum.Value(initialSyn.WindowSize())
+ t.state = ResultConnecting
+}
+
+// UpdateStateInbound updates the state of the TCB based on the supplied inbound
+// segment.
+func (t *TCB) UpdateStateInbound(tcp header.TCP) Result {
+ st := t.handlerInbound(t, tcp)
+ if st != ResultDrop {
+ t.state = st
+ }
+ return st
+}
+
+// UpdateStateOutbound updates the state of the TCB based on the supplied
+// outbound segment.
+func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
+ st := t.handlerOutbound(t, tcp)
+ if st != ResultDrop {
+ t.state = st
+ }
+ return st
+}
+
+// IsAlive returns true as long as the connection is established(Alive)
+// or connecting state.
+func (t *TCB) IsAlive() bool {
+ return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed())
+}
+
+// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream.
+func (t *TCB) OutboundSendSequenceNumber() seqnum.Value {
+ return t.outbound.nxt
+}
+
+// adapResult modifies the supplied "Result" according to the state of the TCB;
+// if r is anything other than "Alive", or if one of the streams isn't closed
+// yet, it is returned unmodified. Otherwise it's converted to either
+// ClosedBySelf or ClosedByPeer depending on which stream was closed first.
+func (t *TCB) adaptResult(r Result) Result {
+ // Check the unmodified case.
+ if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() {
+ return r
+ }
+
+ // Find out which was closed first.
+ if t.firstFin == &t.outbound {
+ return ResultClosedBySelf
+ }
+
+ return ResultClosedByPeer
+}
+
+// synSentStateInbound is the state handler for inbound segments when the
+// connection is in SYN-SENT state.
+func synSentStateInbound(t *TCB, tcp header.TCP) Result {
+ flags := tcp.Flags()
+ ackPresent := flags&header.TCPFlagAck != 0
+ ack := seqnum.Value(tcp.AckNumber())
+
+ // Ignore segment if ack is present but not acceptable.
+ if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) {
+ return ResultConnecting
+ }
+
+ // If reset is specified, we will let the packet through no matter what
+ // but we will also destroy the connection if the ACK is present (and
+ // implicitly acceptable).
+ if flags&header.TCPFlagRst != 0 {
+ if ackPresent {
+ t.inbound.rstSeen = true
+ return ResultReset
+ }
+ return ResultConnecting
+ }
+
+ // Ignore segment if SYN is not set.
+ if flags&header.TCPFlagSyn == 0 {
+ return ResultConnecting
+ }
+
+ // Update state informed by this SYN.
+ irs := seqnum.Value(tcp.SequenceNumber())
+ t.inbound.una = irs
+ t.inbound.nxt = irs.Add(logicalLen(tcp))
+ t.inbound.end += irs
+
+ t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize()))
+
+ // If the ACK was set (it is acceptable), update our unacknowledgement
+ // tracking.
+ if ackPresent {
+ // Advance the "una" and "end" indices of the outbound stream.
+ if t.outbound.una.LessThan(ack) {
+ t.outbound.una = ack
+ }
+
+ if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) {
+ t.outbound.end = end
+ }
+ }
+
+ // Update handlers so that new calls will be handled by new state.
+ t.handlerInbound = allOtherInbound
+ t.handlerOutbound = allOtherOutbound
+
+ return ResultAlive
+}
+
+// synSentStateOutbound is the state handler for outbound segments when the
+// connection is in SYN-SENT state.
+func synSentStateOutbound(t *TCB, tcp header.TCP) Result {
+ // Drop outbound segments that aren't retransmits of the original one.
+ if tcp.Flags() != header.TCPFlagSyn ||
+ tcp.SequenceNumber() != uint32(t.outbound.una) {
+ return ResultDrop
+ }
+
+ // Update the receive window. We only remember the largest value seen.
+ if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end {
+ t.inbound.end = wnd
+ }
+
+ return ResultConnecting
+}
+
+// update updates the state of inbound and outbound streams, given the supplied
+// inbound segment. For outbound segments, this same function can be called with
+// swapped inbound/outbound streams.
+func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result {
+ // Ignore segments out of the window.
+ s := seqnum.Value(tcp.SequenceNumber())
+ if !inbound.acceptable(s, dataLen(tcp)) {
+ return ResultAlive
+ }
+
+ flags := tcp.Flags()
+ if flags&header.TCPFlagRst != 0 {
+ inbound.rstSeen = true
+ return ResultReset
+ }
+
+ // Ignore segments that don't have the ACK flag, and those with the SYN
+ // flag.
+ if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 {
+ return ResultAlive
+ }
+
+ // Ignore segments that acknowledge not yet sent data.
+ ack := seqnum.Value(tcp.AckNumber())
+ if outbound.nxt.LessThan(ack) {
+ return ResultAlive
+ }
+
+ // Advance the "una" and "end" indices of the outbound stream.
+ if outbound.una.LessThan(ack) {
+ outbound.una = ack
+ }
+
+ if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) {
+ outbound.end = end
+ }
+
+ // Advance the "nxt" index of the inbound stream.
+ end := s.Add(logicalLen(tcp))
+ if inbound.nxt.LessThan(end) {
+ inbound.nxt = end
+ }
+
+ // Note the index of the FIN segment. And stash away a pointer to the
+ // first stream to see a FIN.
+ if flags&header.TCPFlagFin != 0 && !inbound.finSeen {
+ inbound.finSeen = true
+ inbound.fin = end - 1
+
+ if *firstFin == nil {
+ *firstFin = inbound
+ }
+ }
+
+ return ResultAlive
+}
+
+// allOtherInbound is the state handler for inbound segments in all states
+// except SYN-SENT.
+func allOtherInbound(t *TCB, tcp header.TCP) Result {
+ return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin))
+}
+
+// allOtherOutbound is the state handler for outbound segments in all states
+// except SYN-SENT.
+func allOtherOutbound(t *TCB, tcp header.TCP) Result {
+ return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin))
+}
+
+// streams holds the state of a TCP unidirectional stream.
+type stream struct {
+ // The interval [una, end) is the allowed interval as defined by the
+ // receiver, i.e., anything less than una has already been acknowledged
+ // and anything greater than or equal to end is beyond the receiver
+ // window. The interval [una, nxt) is the acknowledgable range, whose
+ // right edge indicates the sequence number of the next byte to be sent
+ // by the sender, i.e., anything greater than or equal to nxt hasn't
+ // been sent yet.
+ una seqnum.Value
+ nxt seqnum.Value
+ end seqnum.Value
+
+ // finSeen indicates if a FIN has already been sent on this stream.
+ finSeen bool
+
+ // fin is the sequence number of the FIN. It is only valid after finSeen
+ // is set to true.
+ fin seqnum.Value
+
+ // rstSeen indicates if a RST has already been sent on this stream.
+ rstSeen bool
+}
+
+// acceptable determines if the segment with the given sequence number and data
+// length is acceptable, i.e., if it's within the [una, end) window or, in case
+// the window is zero, if it's a packet with no payload and sequence number
+// equal to una.
+func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+ wnd := s.una.Size(s.end)
+ if wnd == 0 {
+ return segLen == 0 && segSeq == s.una
+ }
+
+ // Make sure [segSeq, seqSeq+segLen) is non-empty.
+ if segLen == 0 {
+ segLen = 1
+ }
+
+ return seqnum.Overlap(s.una, wnd, segSeq, segLen)
+}
+
+// closed determines if the stream has already been closed. This happens when
+// a FIN has been set by the sender and acknowledged by the receiver.
+func (s *stream) closed() bool {
+ return s.finSeen && s.fin.LessThan(s.una)
+}
+
+// dataLen returns the length of the TCP segment payload.
+func dataLen(tcp header.TCP) seqnum.Size {
+ return seqnum.Size(len(tcp) - int(tcp.DataOffset()))
+}
+
+// logicalLen calculates the logical length of the TCP segment.
+func logicalLen(tcp header.TCP) seqnum.Size {
+ l := dataLen(tcp)
+ flags := tcp.Flags()
+ if flags&header.TCPFlagSyn != 0 {
+ l++
+ }
+ if flags&header.TCPFlagFin != 0 {
+ l++
+ }
+ return l
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
new file mode 100644
index 000000000..50cab3132
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -0,0 +1,501 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpconntrack_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack"
+)
+
+// connected creates a connection tracker TCB and sets it to a connected state
+// by performing a 3-way handshake.
+func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: iss,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: irw,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN-ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: irs,
+ AckNum: iss + 1,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ WindowSize: isw,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: iss + 1,
+ AckNum: irs + 1,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: irw,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ return &tcb
+}
+
+func TestConnectionRefused(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive RST.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 1235,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+ }
+}
+
+func TestConnectionRefusedInSynRcvd(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive RST with no ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 790,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagRst,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+ }
+}
+
+func TestConnectionResetInSynRcvd(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send RST with no ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagRst,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+ }
+}
+
+func TestRetransmitOnSynSent(t *testing.T) {
+ // Send initial SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Retransmit the same SYN.
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
+ }
+}
+
+func TestRetransmitOnSynRcvd(t *testing.T) {
+ // Send initial SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive SYN. This will cause the state to go to SYN-RCVD.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Retransmit the original SYN.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Transmit a SYN-ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 790,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+}
+
+func TestClosedBySelf(t *testing.T) {
+ tcb := connected(t, 1234, 789, 30000, 50000)
+
+ // Send FIN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 790,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive FIN/ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 790,
+ AckNum: 1236,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1236,
+ AckNum: 791,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
+ }
+}
+
+func TestClosedByPeer(t *testing.T) {
+ tcb := connected(t, 1234, 789, 30000, 50000)
+
+ // Receive FIN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 790,
+ AckNum: 1235,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send FIN/ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 791,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 791,
+ AckNum: 1236,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer)
+ }
+}
+
+func TestSendAndReceiveDataClosedBySelf(t *testing.T) {
+ sseq := uint32(1234)
+ rseq := uint32(789)
+ tcb := connected(t, sseq, rseq, 30000, 50000)
+ sseq++
+ rseq++
+
+ // Send some data.
+ tcp := make(header.TCP, header.TCPMinimumSize+1024)
+
+ for i := uint32(0); i < 10; i++ {
+ // Send some data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+ sseq += uint32(len(tcp)) - header.TCPMinimumSize
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive ack for data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: rseq,
+ AckNum: sseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+ }
+
+ for i := uint32(0); i < 10; i++ {
+ // Receive some data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: rseq,
+ AckNum: sseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+ rseq += uint32(len(tcp)) - header.TCPMinimumSize
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ack for data.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+ }
+
+ // Send FIN.
+ tcp = tcp[:header.TCPMinimumSize]
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 30000,
+ })
+ sseq++
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Receive FIN/ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: rseq,
+ AckNum: sseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ WindowSize: 50000,
+ })
+ rseq++
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: sseq,
+ AckNum: rseq,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
+ }
+}
+
+func TestIgnoreBadResetOnSynSent(t *testing.T) {
+ // Send SYN.
+ tcp := make(header.TCP, header.TCPMinimumSize)
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1234,
+ AckNum: 0,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 30000,
+ })
+
+ tcb := tcpconntrack.TCB{}
+ tcb.Init(tcp)
+
+ // Receive a RST with a bad ACK, it should not cause the connection to
+ // be reset.
+ acks := []uint32{1234, 1236, 1000, 5000}
+ flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
+ for _, a := range acks {
+ for _, f := range flags {
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: a,
+ DataOffset: header.TCPMinimumSize,
+ Flags: f,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+ }
+ }
+
+ // Complete the handshake.
+ // Receive SYN-ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 789,
+ AckNum: 1235,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ WindowSize: 50000,
+ })
+
+ if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+
+ // Send ACK.
+ tcp.Encode(&header.TCPFields{
+ SeqNum: 1235,
+ AckNum: 790,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagAck,
+ WindowSize: 30000,
+ })
+
+ if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+ t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+ }
+}