diff options
Diffstat (limited to 'pkg/tcpip/checker')
-rw-r--r-- | pkg/tcpip/checker/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/checker/checker.go | 42 |
2 files changed, 36 insertions, 7 deletions
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index ed434807f..c984470e6 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -12,5 +12,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c1745ba6a..1e5f5abf2 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -169,10 +170,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTClass { - t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) - } - if got := cm.TClass; got != want { - t.Fatalf("got cm.TClass = %d, want %d", got, want) + t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) + } else if got := cm.TClass; got != want { + t.Errorf("got cm.TClass = %d, want %d", got, want) } } } @@ -182,10 +182,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) + } else if got := cm.TOS; got != want { + t.Errorf("got cm.TOS = %d, want %d", got, want) } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) + } +} + +// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in +// ControlMessages. +func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPPacketInfo { + t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) + } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { + t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) } } } @@ -320,6 +332,22 @@ func DstPort(port uint16) TransportChecker { } } +// NoChecksum creates a checker that checks if the checksum is zero. +func NoChecksum(noChecksum bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + udp, ok := h.(header.UDP) + if !ok { + return + } + + if b := udp.Checksum() == 0; b != noChecksum { + t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) + } + } +} + // SeqNum creates a checker that checks the sequence number. func SeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { |