summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/checker/checker.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/checker/checker.go')
-rw-r--r--pkg/tcpip/checker/checker.go133
1 files changed, 89 insertions, 44 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 8e0e49efa..206531f20 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -39,40 +39,52 @@ type TransportChecker func(*testing.T, header.Transport)
//
// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
ipv4 := header.IPv4(b)
if !ipv4.IsValid(len(b)) {
- t.Fatalf("Not a valid IPv4 packet")
+ t.Error("Not a valid IPv4 packet")
}
xsum := ipv4.CalculateChecksum()
if xsum != 0 && xsum != 0xffff {
- t.Fatalf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
+ t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
}
for _, f := range checkers {
f(t, []header.Network{ipv4})
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
// IPv6 checks the validity and properties of the given IPv6 packet. The usage
// is similar to IPv4.
func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
ipv6 := header.IPv6(b)
if !ipv6.IsValid(len(b)) {
- t.Fatalf("Not a valid IPv6 packet")
+ t.Error("Not a valid IPv6 packet")
}
for _, f := range checkers {
f(t, []header.Network{ipv6})
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
// SrcAddr creates a checker that checks the source address.
func SrcAddr(addr tcpip.Address) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if a := h[0].SourceAddress(); a != addr {
- t.Fatalf("Bad source address, got %v, want %v", a, addr)
+ t.Errorf("Bad source address, got %v, want %v", a, addr)
}
}
}
@@ -80,8 +92,10 @@ func SrcAddr(addr tcpip.Address) NetworkChecker {
// DstAddr creates a checker that checks the destination address.
func DstAddr(addr tcpip.Address) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if a := h[0].DestinationAddress(); a != addr {
- t.Fatalf("Bad destination address, got %v, want %v", a, addr)
+ t.Errorf("Bad destination address, got %v, want %v", a, addr)
}
}
}
@@ -105,8 +119,10 @@ func TTL(ttl uint8) NetworkChecker {
// PayloadLen creates a checker that checks the payload length.
func PayloadLen(plen int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if l := len(h[0].Payload()); l != plen {
- t.Fatalf("Bad payload length, got %v, want %v", l, plen)
+ t.Errorf("Bad payload length, got %v, want %v", l, plen)
}
}
}
@@ -114,11 +130,13 @@ func PayloadLen(plen int) NetworkChecker {
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
// We only do this of IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Fatalf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
}
}
}
@@ -127,11 +145,13 @@ func FragmentOffset(offset uint16) NetworkChecker {
// FragmentFlags creates a checker that checks the fragment flags field.
func FragmentFlags(flags uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
// We only do this of IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Fatalf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
}
}
}
@@ -140,8 +160,10 @@ func FragmentFlags(flags uint8) NetworkChecker {
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if v, l := h[0].TOS(); v != tos || l != label {
- t.Fatalf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
}
}
}
@@ -153,8 +175,10 @@ func TOS(tos uint8, label uint32) NetworkChecker {
// the bytes added by the IPv6 fragmentation.
func Raw(want []byte) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
- t.Fatalf("Wrong payload, got %v, want %v", got, want)
+ t.Errorf("Wrong payload, got %v, want %v", got, want)
}
}
}
@@ -162,18 +186,23 @@ func Raw(want []byte) NetworkChecker {
// IPv6Fragment creates a checker that validates an IPv6 fragment.
func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
if !ipv6Frag.IsValid() {
- t.Fatalf("Not a valid IPv6 fragment")
+ t.Error("Not a valid IPv6 fragment")
}
for _, f := range checkers {
f(t, []header.Network{h[0], ipv6Frag})
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
}
@@ -181,11 +210,13 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
// potentially additional transport header fields.
func TCP(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
first := h[0]
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Fatalf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -199,13 +230,16 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
xsum = header.Checksum(tcp, xsum)
if xsum != 0 && xsum != 0xffff {
- t.Fatalf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
+ t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
}
// Run the transport checkers.
for _, f := range checkers {
f(t, tcp)
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
}
@@ -213,24 +247,31 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
// potentially additional transport header fields.
func UDP(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
for _, f := range checkers {
f(t, udp)
}
+ if t.Failed() {
+ t.FailNow()
+ }
}
}
// SrcPort creates a checker that checks the source port.
func SrcPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if p := h.SourcePort(); p != port {
- t.Fatalf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got %v, want %v", p, port)
}
}
}
@@ -239,7 +280,7 @@ func SrcPort(port uint16) TransportChecker {
func DstPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
if p := h.DestinationPort(); p != port {
- t.Fatalf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got %v, want %v", p, port)
}
}
}
@@ -247,13 +288,15 @@ func DstPort(port uint16) TransportChecker {
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
}
if s := tcp.SequenceNumber(); s != seq {
- t.Fatalf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got %v, want %v", s, seq)
}
}
}
@@ -268,7 +311,7 @@ func AckNum(seq uint32) TransportChecker {
}
if s := tcp.AckNumber(); s != seq {
- t.Fatalf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got %v, want %v", s, seq)
}
}
}
@@ -282,7 +325,7 @@ func Window(window uint16) TransportChecker {
}
if w := tcp.WindowSize(); w != window {
- t.Fatalf("Bad window, got 0x%x, want 0x%x", w, window)
+ t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
}
}
}
@@ -290,13 +333,15 @@ func Window(window uint16) TransportChecker {
// TCPFlags creates a checker that checks the tcp flags.
func TCPFlags(flags uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
}
if f := tcp.Flags(); f != flags {
- t.Fatalf("Bad flags, got 0x%x, want 0x%x", f, flags)
+ t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
}
}
}
@@ -311,7 +356,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
}
if f := tcp.Flags(); (f & mask) != (flags & mask) {
- t.Fatalf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
+ t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
}
}
}
@@ -343,26 +388,26 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Fatalf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
}
foundMSS = true
i += 4
case header.TCPOptionWS:
if wantOpts.WS < 0 {
- t.Fatalf("WS present when it shouldn't be")
+ t.Error("WS present when it shouldn't be")
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Fatalf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
}
foundWS = true
i += 3
case header.TCPOptionTS:
if i+9 >= limit {
- t.Fatalf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
+ t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
}
if opts[i+1] != 10 {
- t.Fatalf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
+ t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = uint32(0)
@@ -375,10 +420,10 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
i += 10
case header.TCPOptionSACKPermitted:
if i+1 >= limit {
- t.Fatalf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
+ t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
}
if opts[i+1] != 2 {
- t.Fatalf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
+ t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
}
foundSACKPermitted = true
i += 2
@@ -389,23 +434,23 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
}
if !foundMSS {
- t.Fatalf("MSS option not found. Options: %x", opts)
+ t.Errorf("MSS option not found. Options: %x", opts)
}
if !foundWS && wantOpts.WS >= 0 {
- t.Fatalf("WS option not found. Options: %x", opts)
+ t.Errorf("WS option not found. Options: %x", opts)
}
if wantOpts.TS && !foundTS {
- t.Fatalf("TS option not found. Options: %x", opts)
+ t.Errorf("TS option not found. Options: %x", opts)
}
if foundTS && tsVal == 0 {
- t.Fatalf("TS option specified but the timestamp value is zero")
+ t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Fatalf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
- t.Fatalf("SACKPermitted option not found. Options: %x", opts)
+ t.Errorf("SACKPermitted option not found. Options: %x", opts)
}
}
}
@@ -435,10 +480,10 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
i++
case header.TCPOptionTS:
if i+9 >= limit {
- t.Fatalf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
+ t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Fatalf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -458,13 +503,13 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
if wantTS != foundTS {
- t.Fatalf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Fatalf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Fatalf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
}
}
}
@@ -497,12 +542,12 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
case header.TCPOptionSACK:
if i+2 > limit {
// Malformed SACK block.
- t.Fatalf("malformed SACK option in options: %v", opts)
+ t.Errorf("malformed SACK option in options: %v", opts)
}
sackOptionLen := int(opts[i+1])
if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
// Malformed SACK block.
- t.Fatalf("malformed SACK option length in options: %v", opts)
+ t.Errorf("malformed SACK option length in options: %v", opts)
}
numBlocks := sackOptionLen / 8
for j := 0; j < numBlocks; j++ {
@@ -528,7 +573,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Fatalf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -537,7 +582,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
func Payload(want []byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
if got := h.Payload(); !reflect.DeepEqual(got, want) {
- t.Fatalf("Wrong payload, got %v, want %v", got, want)
+ t.Errorf("Wrong payload, got %v, want %v", got, want)
}
}
}