From 2a888a106da39f1d5e280417e48a05341a41f4dd Mon Sep 17 00:00:00 2001 From: Zeling Feng Date: Tue, 9 Mar 2021 17:58:02 -0800 Subject: Give TCP flags a dedicated type - Implement Stringer for it so that we can improve error messages. - Use TCPFlags through the code base. There used to be a mixed usage of byte, uint8 and int as TCP flags. PiperOrigin-RevId: 361940150 --- pkg/tcpip/transport/tcp/connect.go | 6 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/tcp/segment.go | 6 +- pkg/tcpip/transport/tcp/snd.go | 2 +- pkg/tcpip/transport/tcp/tcp_test.go | 66 +++++++++++----------- pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 4 +- pkg/tcpip/transport/tcp/testing/context/context.go | 12 ++-- 7 files changed, 49 insertions(+), 49 deletions(-) (limited to 'pkg/tcpip/transport/tcp') diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 61a173fbb..3404af6bb 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -68,7 +68,7 @@ type handshake struct { ep *endpoint state handshakeState active bool - flags uint8 + flags header.TCPFlags ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. @@ -700,7 +700,7 @@ type tcpFields struct { id stack.TransportEndpointID ttl uint8 tos uint8 - flags byte + flags header.TCPFlags seq seqnum.Value ack seqnum.Value rcvWnd seqnum.Size @@ -877,7 +877,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. -func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 04012cd40..2a4667906 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -226,7 +226,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) - flags := byte(header.TCPFlagRst) + flags := header.TCPFlagRst // As per RFC 793 page 35 (Reset Generation) // 1. If the connection does not exist (CLOSED) then a reset is sent // in response to any incoming segment except another reset. In diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 744382100..8edd6775b 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -62,7 +62,7 @@ type segment struct { views [8]buffer.View `state:"nosave"` sequenceNumber seqnum.Value ackNumber seqnum.Value - flags uint8 + flags header.TCPFlags window seqnum.Size // csum is only populated for received segments. csum uint16 @@ -141,12 +141,12 @@ func (s *segment) clone() *segment { } // flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags uint8) bool { +func (s *segment) flagIsSet(flags header.TCPFlags) bool { return s.flags&flags != 0 } // flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags uint8) bool { +func (s *segment) flagsAreSet(flags header.TCPFlags) bool { return s.flags&flags == flags } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 83c8deb0e..18817029d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1613,7 +1613,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. -func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a684f204d..fd499a47b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1372,7 +1372,7 @@ func TestTOSV4(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -1420,7 +1420,7 @@ func TestTrafficClassV6(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -2201,7 +2201,7 @@ func TestSimpleSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2241,7 +2241,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2263,7 +2263,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2310,7 +2310,7 @@ func TestScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2341,7 +2341,7 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2414,7 +2414,7 @@ func TestScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2487,7 +2487,7 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2665,7 +2665,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2688,7 +2688,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+11), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2737,7 +2737,7 @@ func TestDelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2785,7 +2785,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2808,7 +2808,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2871,7 +2871,7 @@ func TestMSSNotDelayed(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2922,7 +2922,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3437,7 +3437,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) const numRetransmits = 2 @@ -3446,7 +3446,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { @@ -3489,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} @@ -3501,7 +3501,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) id := header.IPv4(pkt).ID() @@ -3632,7 +3632,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3709,7 +3709,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3728,7 +3728,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3795,7 +3795,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3821,7 +3821,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3885,7 +3885,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3906,7 +3906,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3922,7 +3922,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -4032,7 +4032,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4844,7 +4844,7 @@ func TestPathMTUDiscovery(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(seqNum), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) seqNum += uint32(size) @@ -5129,7 +5129,7 @@ func TestKeepalive(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7174,7 +7174,7 @@ func TestTCPCloseWithData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7274,7 +7274,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 5a9745ad7..cb4f82903 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -170,7 +170,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(true, 0, tsVal+1), ), ) @@ -231,7 +231,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(false, 0, 0), ), ) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b1cb9a324..2f1c1011d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -101,7 +101,7 @@ type Headers struct { AckNum seqnum.Value // Flags are the TCP flags in the TCP header. - Flags int + Flags header.TCPFlags // RcvWnd is the window to be advertised in the ReceiveWindow field of // the TCP header. @@ -452,7 +452,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -544,7 +544,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -571,7 +571,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -650,7 +650,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -780,7 +780,7 @@ type RawEndpoint struct { C *Context SrcPort uint16 DstPort uint16 - Flags int + Flags header.TCPFlags NextSeqNum seqnum.Value AckNum seqnum.Value WndSize seqnum.Size -- cgit v1.2.3