diff options
author | Googler <noreply@google.com> | 2018-04-27 10:37:02 -0700 |
---|---|---|
committer | Adin Scannell <ascannell@google.com> | 2018-04-28 01:44:26 -0400 |
commit | d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch) | |
tree | 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/transport/tcpconntrack | |
parent | f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff) |
Check in gVisor.
PiperOrigin-RevId: 194583126
Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/transport/tcpconntrack')
-rw-r--r-- | pkg/tcpip/transport/tcpconntrack/BUILD | 24 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 333 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go | 501 |
3 files changed, 858 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD new file mode 100644 index 000000000..cf83ca134 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -0,0 +1,24 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "tcpconntrack", + srcs = ["tcp_conntrack.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack", + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + ], +) + +go_test( + name = "tcpconntrack_test", + size = "small", + srcs = ["tcp_conntrack_test.go"], + deps = [ + ":tcpconntrack", + "//pkg/tcpip/header", + ], +) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go new file mode 100644 index 000000000..487f2572d --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -0,0 +1,333 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tcpconntrack implements a TCP connection tracking object. It allows +// users with access to a segment stream to figure out when a connection is +// established, reset, and closed (and in the last case, who closed first). +package tcpconntrack + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +// Result is returned when the state of a TCB is updated in response to an +// inbound or outbound segment. +type Result int + +const ( + // ResultDrop indicates that the segment should be dropped. + ResultDrop Result = iota + + // ResultConnecting indicates that the connection remains in a + // connecting state. + ResultConnecting + + // ResultAlive indicates that the connection remains alive (connected). + ResultAlive + + // ResultReset indicates that the connection was reset. + ResultReset + + // ResultClosedByPeer indicates that the connection was gracefully + // closed, and the inbound stream was closed first. + ResultClosedByPeer + + // ResultClosedBySelf indicates that the connection was gracefully + // closed, and the outbound stream was closed first. + ResultClosedBySelf +) + +// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP +// connection and inform the caller when the connection has been closed. +type TCB struct { + inbound stream + outbound stream + + // State handlers. + handlerInbound func(*TCB, header.TCP) Result + handlerOutbound func(*TCB, header.TCP) Result + + // firstFin holds a pointer to the first stream to send a FIN. + firstFin *stream + + // state is the current state of the stream. + state Result +} + +// Init initializes the state of the TCB according to the initial SYN. +func (t *TCB) Init(initialSyn header.TCP) { + t.handlerInbound = synSentStateInbound + t.handlerOutbound = synSentStateOutbound + + iss := seqnum.Value(initialSyn.SequenceNumber()) + t.outbound.una = iss + t.outbound.nxt = iss.Add(logicalLen(initialSyn)) + t.outbound.end = t.outbound.nxt + + // Even though "end" is a sequence number, we don't know the initial + // receive sequence number yet, so we store the window size until we get + // a SYN from the peer. + t.inbound.una = 0 + t.inbound.nxt = 0 + t.inbound.end = seqnum.Value(initialSyn.WindowSize()) + t.state = ResultConnecting +} + +// UpdateStateInbound updates the state of the TCB based on the supplied inbound +// segment. +func (t *TCB) UpdateStateInbound(tcp header.TCP) Result { + st := t.handlerInbound(t, tcp) + if st != ResultDrop { + t.state = st + } + return st +} + +// UpdateStateOutbound updates the state of the TCB based on the supplied +// outbound segment. +func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { + st := t.handlerOutbound(t, tcp) + if st != ResultDrop { + t.state = st + } + return st +} + +// IsAlive returns true as long as the connection is established(Alive) +// or connecting state. +func (t *TCB) IsAlive() bool { + return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed()) +} + +// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream. +func (t *TCB) OutboundSendSequenceNumber() seqnum.Value { + return t.outbound.nxt +} + +// adapResult modifies the supplied "Result" according to the state of the TCB; +// if r is anything other than "Alive", or if one of the streams isn't closed +// yet, it is returned unmodified. Otherwise it's converted to either +// ClosedBySelf or ClosedByPeer depending on which stream was closed first. +func (t *TCB) adaptResult(r Result) Result { + // Check the unmodified case. + if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() { + return r + } + + // Find out which was closed first. + if t.firstFin == &t.outbound { + return ResultClosedBySelf + } + + return ResultClosedByPeer +} + +// synSentStateInbound is the state handler for inbound segments when the +// connection is in SYN-SENT state. +func synSentStateInbound(t *TCB, tcp header.TCP) Result { + flags := tcp.Flags() + ackPresent := flags&header.TCPFlagAck != 0 + ack := seqnum.Value(tcp.AckNumber()) + + // Ignore segment if ack is present but not acceptable. + if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) { + return ResultConnecting + } + + // If reset is specified, we will let the packet through no matter what + // but we will also destroy the connection if the ACK is present (and + // implicitly acceptable). + if flags&header.TCPFlagRst != 0 { + if ackPresent { + t.inbound.rstSeen = true + return ResultReset + } + return ResultConnecting + } + + // Ignore segment if SYN is not set. + if flags&header.TCPFlagSyn == 0 { + return ResultConnecting + } + + // Update state informed by this SYN. + irs := seqnum.Value(tcp.SequenceNumber()) + t.inbound.una = irs + t.inbound.nxt = irs.Add(logicalLen(tcp)) + t.inbound.end += irs + + t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize())) + + // If the ACK was set (it is acceptable), update our unacknowledgement + // tracking. + if ackPresent { + // Advance the "una" and "end" indices of the outbound stream. + if t.outbound.una.LessThan(ack) { + t.outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) { + t.outbound.end = end + } + } + + // Update handlers so that new calls will be handled by new state. + t.handlerInbound = allOtherInbound + t.handlerOutbound = allOtherOutbound + + return ResultAlive +} + +// synSentStateOutbound is the state handler for outbound segments when the +// connection is in SYN-SENT state. +func synSentStateOutbound(t *TCB, tcp header.TCP) Result { + // Drop outbound segments that aren't retransmits of the original one. + if tcp.Flags() != header.TCPFlagSyn || + tcp.SequenceNumber() != uint32(t.outbound.una) { + return ResultDrop + } + + // Update the receive window. We only remember the largest value seen. + if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end { + t.inbound.end = wnd + } + + return ResultConnecting +} + +// update updates the state of inbound and outbound streams, given the supplied +// inbound segment. For outbound segments, this same function can be called with +// swapped inbound/outbound streams. +func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result { + // Ignore segments out of the window. + s := seqnum.Value(tcp.SequenceNumber()) + if !inbound.acceptable(s, dataLen(tcp)) { + return ResultAlive + } + + flags := tcp.Flags() + if flags&header.TCPFlagRst != 0 { + inbound.rstSeen = true + return ResultReset + } + + // Ignore segments that don't have the ACK flag, and those with the SYN + // flag. + if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 { + return ResultAlive + } + + // Ignore segments that acknowledge not yet sent data. + ack := seqnum.Value(tcp.AckNumber()) + if outbound.nxt.LessThan(ack) { + return ResultAlive + } + + // Advance the "una" and "end" indices of the outbound stream. + if outbound.una.LessThan(ack) { + outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) { + outbound.end = end + } + + // Advance the "nxt" index of the inbound stream. + end := s.Add(logicalLen(tcp)) + if inbound.nxt.LessThan(end) { + inbound.nxt = end + } + + // Note the index of the FIN segment. And stash away a pointer to the + // first stream to see a FIN. + if flags&header.TCPFlagFin != 0 && !inbound.finSeen { + inbound.finSeen = true + inbound.fin = end - 1 + + if *firstFin == nil { + *firstFin = inbound + } + } + + return ResultAlive +} + +// allOtherInbound is the state handler for inbound segments in all states +// except SYN-SENT. +func allOtherInbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin)) +} + +// allOtherOutbound is the state handler for outbound segments in all states +// except SYN-SENT. +func allOtherOutbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin)) +} + +// streams holds the state of a TCP unidirectional stream. +type stream struct { + // The interval [una, end) is the allowed interval as defined by the + // receiver, i.e., anything less than una has already been acknowledged + // and anything greater than or equal to end is beyond the receiver + // window. The interval [una, nxt) is the acknowledgable range, whose + // right edge indicates the sequence number of the next byte to be sent + // by the sender, i.e., anything greater than or equal to nxt hasn't + // been sent yet. + una seqnum.Value + nxt seqnum.Value + end seqnum.Value + + // finSeen indicates if a FIN has already been sent on this stream. + finSeen bool + + // fin is the sequence number of the FIN. It is only valid after finSeen + // is set to true. + fin seqnum.Value + + // rstSeen indicates if a RST has already been sent on this stream. + rstSeen bool +} + +// acceptable determines if the segment with the given sequence number and data +// length is acceptable, i.e., if it's within the [una, end) window or, in case +// the window is zero, if it's a packet with no payload and sequence number +// equal to una. +func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + wnd := s.una.Size(s.end) + if wnd == 0 { + return segLen == 0 && segSeq == s.una + } + + // Make sure [segSeq, seqSeq+segLen) is non-empty. + if segLen == 0 { + segLen = 1 + } + + return seqnum.Overlap(s.una, wnd, segSeq, segLen) +} + +// closed determines if the stream has already been closed. This happens when +// a FIN has been set by the sender and acknowledged by the receiver. +func (s *stream) closed() bool { + return s.finSeen && s.fin.LessThan(s.una) +} + +// dataLen returns the length of the TCP segment payload. +func dataLen(tcp header.TCP) seqnum.Size { + return seqnum.Size(len(tcp) - int(tcp.DataOffset())) +} + +// logicalLen calculates the logical length of the TCP segment. +func logicalLen(tcp header.TCP) seqnum.Size { + l := dataLen(tcp) + flags := tcp.Flags() + if flags&header.TCPFlagSyn != 0 { + l++ + } + if flags&header.TCPFlagFin != 0 { + l++ + } + return l +} diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go new file mode 100644 index 000000000..50cab3132 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go @@ -0,0 +1,501 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpconntrack_test + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack" +) + +// connected creates a connection tracker TCB and sets it to a connected state +// by performing a 3-way handshake. +func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: iss, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: irw, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: irs, + AckNum: iss + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: isw, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: iss + 1, + AckNum: irs + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: irw, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + return &tcb +} + +func TestConnectionRefused(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive RST. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst | header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionRefusedInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive RST with no ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionResetInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send RST with no ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestRetransmitOnSynSent(t *testing.T) { + // Send initial SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Retransmit the same SYN. + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) + } +} + +func TestRetransmitOnSynRcvd(t *testing.T) { + // Send initial SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. This will cause the state to go to SYN-RCVD. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Retransmit the original SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Transmit a SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } +} + +func TestClosedBySelf(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Send FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1236, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) + } +} + +func TestClosedByPeer(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Receive FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 791, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) + } +} + +func TestSendAndReceiveDataClosedBySelf(t *testing.T) { + sseq := uint32(1234) + rseq := uint32(789) + tcb := connected(t, sseq, rseq, 30000, 50000) + sseq++ + rseq++ + + // Send some data. + tcp := make(header.TCP, header.TCPMinimumSize+1024) + + for i := uint32(0); i < 10; i++ { + // Send some data. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + sseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ack for data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + for i := uint32(0); i < 10; i++ { + // Receive some data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + rseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ack for data. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + // Send FIN. + tcp = tcp[:header.TCPMinimumSize] + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + sseq++ + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + rseq++ + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) + } +} + +func TestIgnoreBadResetOnSynSent(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive a RST with a bad ACK, it should not cause the connection to + // be reset. + acks := []uint32{1234, 1236, 1000, 5000} + flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} + for _, a := range acks { + for _, f := range flags { + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: a, + DataOffset: header.TCPMinimumSize, + Flags: f, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + } + + // Complete the handshake. + // Receive SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } +} |