summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/checker/checker.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/checker/checker.go')
-rw-r--r--pkg/tcpip/checker/checker.go42
1 files changed, 35 insertions, 7 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index c1745ba6a..1e5f5abf2 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -21,6 +21,7 @@ import (
"reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -169,10 +170,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
t.Helper()
if !cm.HasTClass {
- t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
- }
- if got := cm.TClass; got != want {
- t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
+ } else if got := cm.TClass; got != want {
+ t.Errorf("got cm.TClass = %d, want %d", got, want)
}
}
}
@@ -182,10 +182,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
t.Helper()
if !cm.HasTOS {
- t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
+ t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
+ } else if got := cm.TOS; got != want {
+ t.Errorf("got cm.TOS = %d, want %d", got, want)
}
- if got := cm.TOS; got != want {
- t.Fatalf("got cm.TOS = %d, want %d", got, want)
+ }
+}
+
+// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
+// ControlMessages.
+func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPPacketInfo {
+ t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
+ } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
+ t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
}
}
}
@@ -320,6 +332,22 @@ func DstPort(port uint16) TransportChecker {
}
}
+// NoChecksum creates a checker that checks if the checksum is zero.
+func NoChecksum(noChecksum bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ udp, ok := h.(header.UDP)
+ if !ok {
+ return
+ }
+
+ if b := udp.Checksum() == 0; b != noChecksum {
+ t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
+ }
+ }
+}
+
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {