summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/checker
diff options
context:
space:
mode:
authorJulian Elischer <jrelis@google.com>2020-09-30 13:03:15 -0700
committergVisor bot <gvisor-bot@google.com>2020-09-30 13:05:14 -0700
commit694d6ae32fbed0a62bc9d73f279db205815681e3 (patch)
tree1e1379b0b9541eb68b61f397a18abdc04f990850 /pkg/tcpip/checker
parent3e450a991b844a6b45fc57e59bfb0030ba0d0f4c (diff)
Use the ICMP error response facility
Add code in IPv6 to send ICMP packets while processing extension headers. Add some accounting in processing IPV6 Extension headers which allows us to report meaningful information back in ICMP parameter problem packets. IPv4 also needs to send a message when an unsupported protocol is requested. Add some tests to generate both ipv4 and ipv6 packets with various errors and check the responses. Add some new checkers and cleanup some inconsistencies in the messages in that file. Add new error types for the ICMPv4/6 generators. Fix a bug in the ICMPv4 generator that stopped it from generating "Unknown protocol" messages. Updates #2211 PiperOrigin-RevId: 334661716
Diffstat (limited to 'pkg/tcpip/checker')
-rw-r--r--pkg/tcpip/checker/checker.go170
1 files changed, 140 insertions, 30 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 19627fa9b..71b2f1bda 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -118,18 +118,62 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.HopLimit()
}
if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
+ }
+ }
+}
+
+// IPFullLength creates a checker for the full IP packet length. The
+// expected size is checked against both the Total Length in the
+// header and the number of bytes received.
+func IPFullLength(packetLength uint16) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ var v uint16
+ var l uint16
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TotalLength()
+ l = uint16(len(ip))
+ case header.IPv6:
+ v = ip.PayloadLength() + header.IPv6FixedHeaderSize
+ l = uint16(len(ip))
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
+ }
+ if l != packetLength {
+ t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
+ }
+ if v != packetLength {
+ t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
+ }
+ }
+}
+
+// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
+func IPv4HeaderLength(headerLength int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if hl := ip.HeaderLength(); hl != uint8(headerLength) {
+ t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
+ }
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
}
}
}
// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
+func PayloadLen(payloadLength int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
+ if l := len(h[0].Payload()); l != payloadLength {
+ t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
}
}
}
@@ -143,7 +187,7 @@ func FragmentOffset(offset uint16) NetworkChecker {
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
}
}
}
@@ -158,7 +202,7 @@ func FragmentFlags(flags uint8) NetworkChecker {
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
}
}
}
@@ -208,7 +252,7 @@ func TOS(tos uint8, label uint32) NetworkChecker {
t.Helper()
if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
}
}
}
@@ -234,7 +278,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
t.Helper()
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
@@ -261,7 +305,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -297,7 +341,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
@@ -316,7 +360,7 @@ func SrcPort(port uint16) TransportChecker {
t.Helper()
if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got = %d, want = %d", p, port)
}
}
}
@@ -327,7 +371,7 @@ func DstPort(port uint16) TransportChecker {
t.Helper()
if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got = %d, want = %d", p, port)
}
}
}
@@ -359,7 +403,7 @@ func TCPSeqNum(seq uint32) TransportChecker {
}
if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
}
}
}
@@ -375,7 +419,7 @@ func TCPAckNum(seq uint32) TransportChecker {
}
if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
}
}
}
@@ -492,7 +536,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
}
foundMSS = true
i += 4
@@ -502,7 +546,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
}
foundWS = true
i += 3
@@ -551,7 +595,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
t.Errorf("SACKPermitted option not found. Options: %x", opts)
@@ -589,7 +633,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -609,13 +653,13 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
}
}
}
@@ -679,7 +723,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -724,10 +768,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -739,10 +783,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
+func ICMPv4Ident(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Ident(); got != want {
+ t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
+func ICMPv4Seq(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Sequence(); got != want {
+ t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
+// This assumes that the payload exactly makes up the rest of the slice.
+func ICMPv4Checksum() TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ heldChecksum := icmpv4.Checksum()
+ icmpv4.SetChecksum(0)
+ newChecksum := ^header.Checksum(icmpv4, 0)
+ icmpv4.SetChecksum(heldChecksum)
+ if heldChecksum != newChecksum {
+ t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
+ }
+ }
+}
+
+// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
+func ICMPv4Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ payload := icmpv4.Payload()
+ if diff := cmp.Diff(payload, want); diff != "" {
+ t.Errorf("got ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
@@ -782,10 +892,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -797,10 +907,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
}
}
}