diff options
Diffstat (limited to 'pkg/tcpip/checker/checker.go')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 133 |
1 files changed, 89 insertions, 44 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 8e0e49efa..206531f20 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -39,40 +39,52 @@ type TransportChecker func(*testing.T, header.Transport) // // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + ipv4 := header.IPv4(b) if !ipv4.IsValid(len(b)) { - t.Fatalf("Not a valid IPv4 packet") + t.Error("Not a valid IPv4 packet") } xsum := ipv4.CalculateChecksum() if xsum != 0 && xsum != 0xffff { - t.Fatalf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) + t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) } for _, f := range checkers { f(t, []header.Network{ipv4}) } + if t.Failed() { + t.FailNow() + } } // IPv6 checks the validity and properties of the given IPv6 packet. The usage // is similar to IPv4. func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + ipv6 := header.IPv6(b) if !ipv6.IsValid(len(b)) { - t.Fatalf("Not a valid IPv6 packet") + t.Error("Not a valid IPv6 packet") } for _, f := range checkers { f(t, []header.Network{ipv6}) } + if t.Failed() { + t.FailNow() + } } // SrcAddr creates a checker that checks the source address. func SrcAddr(addr tcpip.Address) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if a := h[0].SourceAddress(); a != addr { - t.Fatalf("Bad source address, got %v, want %v", a, addr) + t.Errorf("Bad source address, got %v, want %v", a, addr) } } } @@ -80,8 +92,10 @@ func SrcAddr(addr tcpip.Address) NetworkChecker { // DstAddr creates a checker that checks the destination address. func DstAddr(addr tcpip.Address) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if a := h[0].DestinationAddress(); a != addr { - t.Fatalf("Bad destination address, got %v, want %v", a, addr) + t.Errorf("Bad destination address, got %v, want %v", a, addr) } } } @@ -105,8 +119,10 @@ func TTL(ttl uint8) NetworkChecker { // PayloadLen creates a checker that checks the payload length. func PayloadLen(plen int) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if l := len(h[0].Payload()); l != plen { - t.Fatalf("Bad payload length, got %v, want %v", l, plen) + t.Errorf("Bad payload length, got %v, want %v", l, plen) } } } @@ -114,11 +130,13 @@ func PayloadLen(plen int) NetworkChecker { // FragmentOffset creates a checker that checks the FragmentOffset field. func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + // We only do this of IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.FragmentOffset(); v != offset { - t.Fatalf("Bad fragment offset, got %v, want %v", v, offset) + t.Errorf("Bad fragment offset, got %v, want %v", v, offset) } } } @@ -127,11 +145,13 @@ func FragmentOffset(offset uint16) NetworkChecker { // FragmentFlags creates a checker that checks the fragment flags field. func FragmentFlags(flags uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + // We only do this of IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.Flags(); v != flags { - t.Fatalf("Bad fragment offset, got %v, want %v", v, flags) + t.Errorf("Bad fragment offset, got %v, want %v", v, flags) } } } @@ -140,8 +160,10 @@ func FragmentFlags(flags uint8) NetworkChecker { // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if v, l := h[0].TOS(); v != tos || l != label { - t.Fatalf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) + t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) } } } @@ -153,8 +175,10 @@ func TOS(tos uint8, label uint32) NetworkChecker { // the bytes added by the IPv6 fragmentation. func Raw(want []byte) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Fatalf("Wrong payload, got %v, want %v", got, want) + t.Errorf("Wrong payload, got %v, want %v", got, want) } } } @@ -162,18 +186,23 @@ func Raw(want []byte) NetworkChecker { // IPv6Fragment creates a checker that validates an IPv6 fragment. func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) } ipv6Frag := header.IPv6Fragment(h[0].Payload()) if !ipv6Frag.IsValid() { - t.Fatalf("Not a valid IPv6 fragment") + t.Error("Not a valid IPv6 fragment") } for _, f := range checkers { f(t, []header.Network{h[0], ipv6Frag}) } + if t.Failed() { + t.FailNow() + } } } @@ -181,11 +210,13 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { // potentially additional transport header fields. func TCP(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + first := h[0] last := h[len(h)-1] if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Fatalf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) + t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) } // Verify the checksum. @@ -199,13 +230,16 @@ func TCP(checkers ...TransportChecker) NetworkChecker { xsum = header.Checksum(tcp, xsum) if xsum != 0 && xsum != 0xffff { - t.Fatalf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) + t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) } // Run the transport checkers. for _, f := range checkers { f(t, tcp) } + if t.Failed() { + t.FailNow() + } } } @@ -213,24 +247,31 @@ func TCP(checkers ...TransportChecker) NetworkChecker { // potentially additional transport header fields. func UDP(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + last := h[len(h)-1] if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) } udp := header.UDP(last.Payload()) for _, f := range checkers { f(t, udp) } + if t.Failed() { + t.FailNow() + } } } // SrcPort creates a checker that checks the source port. func SrcPort(port uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if p := h.SourcePort(); p != port { - t.Fatalf("Bad source port, got %v, want %v", p, port) + t.Errorf("Bad source port, got %v, want %v", p, port) } } } @@ -239,7 +280,7 @@ func SrcPort(port uint16) TransportChecker { func DstPort(port uint16) TransportChecker { return func(t *testing.T, h header.Transport) { if p := h.DestinationPort(); p != port { - t.Fatalf("Bad destination port, got %v, want %v", p, port) + t.Errorf("Bad destination port, got %v, want %v", p, port) } } } @@ -247,13 +288,15 @@ func DstPort(port uint16) TransportChecker { // SeqNum creates a checker that checks the sequence number. func SeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return } if s := tcp.SequenceNumber(); s != seq { - t.Fatalf("Bad sequence number, got %v, want %v", s, seq) + t.Errorf("Bad sequence number, got %v, want %v", s, seq) } } } @@ -268,7 +311,7 @@ func AckNum(seq uint32) TransportChecker { } if s := tcp.AckNumber(); s != seq { - t.Fatalf("Bad ack number, got %v, want %v", s, seq) + t.Errorf("Bad ack number, got %v, want %v", s, seq) } } } @@ -282,7 +325,7 @@ func Window(window uint16) TransportChecker { } if w := tcp.WindowSize(); w != window { - t.Fatalf("Bad window, got 0x%x, want 0x%x", w, window) + t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) } } } @@ -290,13 +333,15 @@ func Window(window uint16) TransportChecker { // TCPFlags creates a checker that checks the tcp flags. func TCPFlags(flags uint8) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return } if f := tcp.Flags(); f != flags { - t.Fatalf("Bad flags, got 0x%x, want 0x%x", f, flags) + t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) } } } @@ -311,7 +356,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { } if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Fatalf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) + t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) } } } @@ -343,26 +388,26 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { case header.TCPOptionMSS: v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) if wantOpts.MSS != v { - t.Fatalf("Bad MSS: got %v, want %v", v, wantOpts.MSS) + t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) } foundMSS = true i += 4 case header.TCPOptionWS: if wantOpts.WS < 0 { - t.Fatalf("WS present when it shouldn't be") + t.Error("WS present when it shouldn't be") } v := int(opts[i+2]) if v != wantOpts.WS { - t.Fatalf("Bad WS: got %v, want %v", v, wantOpts.WS) + t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) } foundWS = true i += 3 case header.TCPOptionTS: if i+9 >= limit { - t.Fatalf("TS Option truncated , option is only: %d bytes, want 10", limit-i) + t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) } if opts[i+1] != 10 { - t.Fatalf("Bad length %d for TS option, limit: %d", opts[i+1], limit) + t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = uint32(0) @@ -375,10 +420,10 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { i += 10 case header.TCPOptionSACKPermitted: if i+1 >= limit { - t.Fatalf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) + t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) } if opts[i+1] != 2 { - t.Fatalf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) + t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) } foundSACKPermitted = true i += 2 @@ -389,23 +434,23 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { } if !foundMSS { - t.Fatalf("MSS option not found. Options: %x", opts) + t.Errorf("MSS option not found. Options: %x", opts) } if !foundWS && wantOpts.WS >= 0 { - t.Fatalf("WS option not found. Options: %x", opts) + t.Errorf("WS option not found. Options: %x", opts) } if wantOpts.TS && !foundTS { - t.Fatalf("TS option not found. Options: %x", opts) + t.Errorf("TS option not found. Options: %x", opts) } if foundTS && tsVal == 0 { - t.Fatalf("TS option specified but the timestamp value is zero") + t.Error("TS option specified but the timestamp value is zero") } if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Fatalf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) + t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) } if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Fatalf("SACKPermitted option not found. Options: %x", opts) + t.Errorf("SACKPermitted option not found. Options: %x", opts) } } } @@ -435,10 +480,10 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp i++ case header.TCPOptionTS: if i+9 >= limit { - t.Fatalf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) + t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) } if opts[i+1] != 10 { - t.Fatalf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) + t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = binary.BigEndian.Uint32(opts[i+6:]) @@ -458,13 +503,13 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } if wantTS != foundTS { - t.Fatalf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) + t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) } if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Fatalf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) + t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) } if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Fatalf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) + t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) } } } @@ -497,12 +542,12 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { case header.TCPOptionSACK: if i+2 > limit { // Malformed SACK block. - t.Fatalf("malformed SACK option in options: %v", opts) + t.Errorf("malformed SACK option in options: %v", opts) } sackOptionLen := int(opts[i+1]) if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { // Malformed SACK block. - t.Fatalf("malformed SACK option length in options: %v", opts) + t.Errorf("malformed SACK option length in options: %v", opts) } numBlocks := sackOptionLen / 8 for j := 0; j < numBlocks; j++ { @@ -528,7 +573,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Fatalf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) + t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) } } } @@ -537,7 +582,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { func Payload(want []byte) TransportChecker { return func(t *testing.T, h header.Transport) { if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Fatalf("Wrong payload, got %v, want %v", got, want) + t.Errorf("Wrong payload, got %v, want %v", got, want) } } } |