diff options
Diffstat (limited to 'pkg/tcpip/header')
-rw-r--r-- | pkg/tcpip/header/tcp.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp_test.go | 20 |
2 files changed, 41 insertions, 7 deletions
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 4c6f808e5..adc835d30 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -45,9 +45,23 @@ const ( TCPMaxSACKBlocks = 4 ) +// TCPFlags is the dedicated type for TCP flags. +type TCPFlags uint8 + +// String implements Stringer.String. +func (f TCPFlags) String() string { + flagsStr := []byte("FSRPAU") + for i := range flagsStr { + if f&(1<<uint(i)) == 0 { + flagsStr[i] = ' ' + } + } + return string(flagsStr) +} + // Flags that may be set in a TCP segment. const ( - TCPFlagFin = 1 << iota + TCPFlagFin TCPFlags = 1 << iota TCPFlagSyn TCPFlagRst TCPFlagPsh @@ -94,7 +108,7 @@ type TCPFields struct { DataOffset uint8 // Flags is the "flags" field of a TCP packet. - Flags uint8 + Flags TCPFlags // WindowSize is the "window size" field of a TCP packet. WindowSize uint16 @@ -234,8 +248,8 @@ func (b TCP) Payload() []byte { } // Flags returns the flags field of the tcp header. -func (b TCP) Flags() uint8 { - return b[TCPFlagsOffset] +func (b TCP) Flags() TCPFlags { + return TCPFlags(b[TCPFlagsOffset]) } // WindowSize returns the "window size" field of the tcp header. @@ -319,10 +333,10 @@ func (b TCP) ParsedOptions() TCPOptions { return ParseTCPOptions(b.Options()) } -func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) { +func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq) binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack) - b[TCPFlagsOffset] = flags + b[TCPFlagsOffset] = uint8(flags) binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } @@ -338,7 +352,7 @@ func (b TCP) Encode(t *TCPFields) { // EncodePartial updates a subset of the fields of the tcp header. It is useful // in cases when similar segments are produced. -func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) { +func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. // We don't use the flags field directly from the header because it's a // one-byte field with an odd offset, so it would be accounted for diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go index 72563837b..96db8460f 100644 --- a/pkg/tcpip/header/tcp_test.go +++ b/pkg/tcpip/header/tcp_test.go @@ -146,3 +146,23 @@ func TestTCPParseOptions(t *testing.T) { } } } + +func TestTCPFlags(t *testing.T) { + for _, tt := range []struct { + flags header.TCPFlags + want string + }{ + {header.TCPFlagFin, "F "}, + {header.TCPFlagSyn, " S "}, + {header.TCPFlagRst, " R "}, + {header.TCPFlagPsh, " P "}, + {header.TCPFlagAck, " A "}, + {header.TCPFlagUrg, " U"}, + {header.TCPFlagSyn | header.TCPFlagAck, " S A "}, + {header.TCPFlagFin | header.TCPFlagAck, "F A "}, + } { + if got := tt.flags.String(); got != tt.want { + t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want) + } + } +} |