From 8f70c6ef351110cf94e758b6dc295387f0388707 Mon Sep 17 00:00:00 2001 From: Arthur Sfez Date: Thu, 15 Oct 2020 09:27:03 -0700 Subject: Refactor compareFragments to follow Go style Test helpers should be used for test setup/teardown, not actual testing. Use cmp.Diff instead of bytes.Equal to improve readability. PiperOrigin-RevId: 337323242 --- pkg/tcpip/network/ipv4/ipv4_test.go | 62 ++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 29 deletions(-) (limited to 'pkg/tcpip/network/ipv4') diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9916d783f..819b5d71f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,9 +15,9 @@ package ipv4_test import ( - "bytes" "context" "encoding/hex" + "fmt" "math" "net" "testing" @@ -243,7 +243,7 @@ func TestIPv4Sanity(t *testing.T) { // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv4EmptySubnet, NIC: nicID, }, @@ -369,11 +369,10 @@ func TestIPv4Sanity(t *testing.T) { // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) - vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32) error { + // Make a complete array of the sourcePacket packet. + source := header.IPv4(packets[0].NetworkHeader().View()) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make @@ -384,46 +383,49 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI sourceCopy.SetTotalLength(0) var offset uint16 // Build up an array of the bytes sent. - var reassembledPayload []byte + var reassembledPayload buffer.VectorisedView for i, packet := range packets { // Confirm that the packet is valid. allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) + fragmentIPHeader := header.IPv4(allBytes.ToView()) + if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) + if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { + return fmt.Errorf("fragment #%d: fragmentIPHeader.CalculateChecksum() got %#x, want %#x", i, got, want) } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) + if got := len(fragmentIPHeader); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) } - if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) + if got, want := packet.AvailableHeaderBytes(), sourcePacket.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { + return fmt.Errorf("fragment #%d: should have the same available space for prepending as source: got %d, want %d", i, got, want) } - if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { - t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { + return fmt.Errorf("fragment #%d: has wrong network protocol number: got %d, want %d", i, got, want) } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) + reassembledPayload.AppendView(packet.TransportHeader().View()) + reassembledPayload.Append(packet.Data) + offset += fragmentIPHeader.TotalLength() - uint16(fragmentIPHeader.HeaderLength()) // Clear out the checksum and length from the ip because we can't compare // it. - sourceCopy.SetTotalLength(uint16(len(ip))) + sourceCopy.SetTotalLength(uint16(len(fragmentIPHeader))) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) + if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader[:fragmentIPHeader.HeaderLength()] mismatch (-want +got):\n%s", i, diff) } } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) + expected := buffer.View(source[source.HeaderLength():]) + if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } + + return nil } func TestFragmentation(t *testing.T) { @@ -477,7 +479,9 @@ func TestFragmentation(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - compareFragments(t, ep.WrittenPackets, source, ft.mtu) + if err := compareFragments(ep.WrittenPackets, source, ft.mtu); err != nil { + t.Error(err) + } }) } } @@ -1633,7 +1637,7 @@ func TestPacketQueing(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), NIC: nicID, }, -- cgit v1.2.3