diff options
author | Zeling Feng <zeling@google.com> | 2021-03-09 17:58:02 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-03-09 18:00:03 -0800 |
commit | 2a888a106da39f1d5e280417e48a05341a41f4dd (patch) | |
tree | f1e5980bcea761aa323540af82311a13352b043f /pkg | |
parent | 6ef5bdab21e1e700a362a38435b57c9a1010aaf4 (diff) |
Give TCP flags a dedicated type
- Implement Stringer for it so that we can improve error messages.
- Use TCPFlags through the code base. There used to be a mixed usage of byte,
uint8 and int as TCP flags.
PiperOrigin-RevId: 361940150
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp_test.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/link/sniffer/sniffer.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 66 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go | 2 |
12 files changed, 98 insertions, 70 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 75d8e1f03..fc622b246 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -567,7 +567,7 @@ func TCPWindowLessThanEq(window uint16) TransportChecker { } // TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { +func TCPFlags(flags header.TCPFlags) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() @@ -576,15 +576,15 @@ func TCPFlags(flags uint8) TransportChecker { t.Fatalf("TCP header not found in h: %T", h) } - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) + if got := tcp.Flags(); got != flags { + t.Errorf("got tcp.Flags() = %s, want %s", got, flags) } } } // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the // given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { +func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() @@ -593,8 +593,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { t.Fatalf("TCP header not found in h: %T", h) } - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) + if got := tcp.Flags(); (got & mask) != (flags & mask) { + t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask) } } } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 4c6f808e5..adc835d30 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -45,9 +45,23 @@ const ( TCPMaxSACKBlocks = 4 ) +// TCPFlags is the dedicated type for TCP flags. +type TCPFlags uint8 + +// String implements Stringer.String. +func (f TCPFlags) String() string { + flagsStr := []byte("FSRPAU") + for i := range flagsStr { + if f&(1<<uint(i)) == 0 { + flagsStr[i] = ' ' + } + } + return string(flagsStr) +} + // Flags that may be set in a TCP segment. const ( - TCPFlagFin = 1 << iota + TCPFlagFin TCPFlags = 1 << iota TCPFlagSyn TCPFlagRst TCPFlagPsh @@ -94,7 +108,7 @@ type TCPFields struct { DataOffset uint8 // Flags is the "flags" field of a TCP packet. - Flags uint8 + Flags TCPFlags // WindowSize is the "window size" field of a TCP packet. WindowSize uint16 @@ -234,8 +248,8 @@ func (b TCP) Payload() []byte { } // Flags returns the flags field of the tcp header. -func (b TCP) Flags() uint8 { - return b[TCPFlagsOffset] +func (b TCP) Flags() TCPFlags { + return TCPFlags(b[TCPFlagsOffset]) } // WindowSize returns the "window size" field of the tcp header. @@ -319,10 +333,10 @@ func (b TCP) ParsedOptions() TCPOptions { return ParseTCPOptions(b.Options()) } -func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) { +func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq) binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack) - b[TCPFlagsOffset] = flags + b[TCPFlagsOffset] = uint8(flags) binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } @@ -338,7 +352,7 @@ func (b TCP) Encode(t *TCPFields) { // EncodePartial updates a subset of the fields of the tcp header. It is useful // in cases when similar segments are produced. -func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) { +func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. // We don't use the flags field directly from the header because it's a // one-byte field with an odd offset, so it would be accounted for diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go index 72563837b..96db8460f 100644 --- a/pkg/tcpip/header/tcp_test.go +++ b/pkg/tcpip/header/tcp_test.go @@ -146,3 +146,23 @@ func TestTCPParseOptions(t *testing.T) { } } } + +func TestTCPFlags(t *testing.T) { + for _, tt := range []struct { + flags header.TCPFlags + want string + }{ + {header.TCPFlagFin, "F "}, + {header.TCPFlagSyn, " S "}, + {header.TCPFlagRst, " R "}, + {header.TCPFlagPsh, " P "}, + {header.TCPFlagAck, " A "}, + {header.TCPFlagUrg, " U"}, + {header.TCPFlagSyn | header.TCPFlagAck, " S A "}, + {header.TCPFlagFin | header.TCPFlagAck, "F A "}, + } { + if got := tt.flags.String(); got != tt.want { + t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want) + } + } +} diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 84189bba5..7aaee3d13 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -398,13 +398,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe // Initialize the TCP flags. flags := tcp.Flags() - flagsStr := []byte("FSRPAU") - for i := range flagsStr { - if flags&(1<<uint(i)) == 0 { - flagsStr[i] = ' ' - } - } - details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) + details = fmt.Sprintf("flags: %s seqnum: %d ack: %d win: %d xsum:0x%x", flags, tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) if flags&header.TCPFlagSyn != 0 { details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) } else { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 61a173fbb..3404af6bb 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -68,7 +68,7 @@ type handshake struct { ep *endpoint state handshakeState active bool - flags uint8 + flags header.TCPFlags ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. @@ -700,7 +700,7 @@ type tcpFields struct { id stack.TransportEndpointID ttl uint8 tos uint8 - flags byte + flags header.TCPFlags seq seqnum.Value ack seqnum.Value rcvWnd seqnum.Size @@ -877,7 +877,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. -func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 04012cd40..2a4667906 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -226,7 +226,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) - flags := byte(header.TCPFlagRst) + flags := header.TCPFlagRst // As per RFC 793 page 35 (Reset Generation) // 1. If the connection does not exist (CLOSED) then a reset is sent // in response to any incoming segment except another reset. In diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 744382100..8edd6775b 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -62,7 +62,7 @@ type segment struct { views [8]buffer.View `state:"nosave"` sequenceNumber seqnum.Value ackNumber seqnum.Value - flags uint8 + flags header.TCPFlags window seqnum.Size // csum is only populated for received segments. csum uint16 @@ -141,12 +141,12 @@ func (s *segment) clone() *segment { } // flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags uint8) bool { +func (s *segment) flagIsSet(flags header.TCPFlags) bool { return s.flags&flags != 0 } // flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags uint8) bool { +func (s *segment) flagsAreSet(flags header.TCPFlags) bool { return s.flags&flags == flags } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 83c8deb0e..18817029d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1613,7 +1613,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. -func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a684f204d..fd499a47b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1372,7 +1372,7 @@ func TestTOSV4(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -1420,7 +1420,7 @@ func TestTrafficClassV6(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -2201,7 +2201,7 @@ func TestSimpleSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2241,7 +2241,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2263,7 +2263,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2310,7 +2310,7 @@ func TestScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2341,7 +2341,7 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2414,7 +2414,7 @@ func TestScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2487,7 +2487,7 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2665,7 +2665,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2688,7 +2688,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+11), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2737,7 +2737,7 @@ func TestDelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2785,7 +2785,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2808,7 +2808,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2871,7 +2871,7 @@ func TestMSSNotDelayed(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2922,7 +2922,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3437,7 +3437,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) const numRetransmits = 2 @@ -3446,7 +3446,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { @@ -3489,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} @@ -3501,7 +3501,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) id := header.IPv4(pkt).ID() @@ -3632,7 +3632,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3709,7 +3709,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3728,7 +3728,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3795,7 +3795,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3821,7 +3821,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3885,7 +3885,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3906,7 +3906,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3922,7 +3922,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -4032,7 +4032,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4844,7 +4844,7 @@ func TestPathMTUDiscovery(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(seqNum), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) seqNum += uint32(size) @@ -5129,7 +5129,7 @@ func TestKeepalive(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7174,7 +7174,7 @@ func TestTCPCloseWithData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7274,7 +7274,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 5a9745ad7..cb4f82903 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -170,7 +170,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(true, 0, tsVal+1), ), ) @@ -231,7 +231,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(false, 0, 0), ), ) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b1cb9a324..2f1c1011d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -101,7 +101,7 @@ type Headers struct { AckNum seqnum.Value // Flags are the TCP flags in the TCP header. - Flags int + Flags header.TCPFlags // RcvWnd is the window to be advertised in the ReceiveWindow field of // the TCP header. @@ -452,7 +452,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -544,7 +544,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -571,7 +571,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -650,7 +650,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -780,7 +780,7 @@ type RawEndpoint struct { C *Context SrcPort uint16 DstPort uint16 - Flags int + Flags header.TCPFlags NextSeqNum seqnum.Value AckNum seqnum.Value WndSize seqnum.Size diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go index 5e271b7ca..6c5ddc3c7 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go @@ -465,7 +465,7 @@ func TestIgnoreBadResetOnSynSent(t *testing.T) { // Receive a RST with a bad ACK, it should not cause the connection to // be reset. acks := []uint32{1234, 1236, 1000, 5000} - flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} + flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} for _, a := range acks { for _, f := range flags { tcp.Encode(&header.TCPFields{ |