summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorArthur Sfez <asfez@google.com>2020-10-15 09:27:03 -0700
committergVisor bot <gvisor-bot@google.com>2020-10-15 09:29:01 -0700
commit8f70c6ef351110cf94e758b6dc295387f0388707 (patch)
tree6fc481c8fcbe20f55ea67099543d6c0ee63acc25
parent6e6a9d3f3dd6dd9ce290952406ba7ca7c5570311 (diff)
Refactor compareFragments to follow Go style
Test helpers should be used for test setup/teardown, not actual testing. Use cmp.Diff instead of bytes.Equal to improve readability. PiperOrigin-RevId: 337323242
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go62
1 files changed, 33 insertions, 29 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 9916d783f..819b5d71f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,9 +15,9 @@
package ipv4_test
import (
- "bytes"
"context"
"encoding/hex"
+ "fmt"
"math"
"net"
"testing"
@@ -243,7 +243,7 @@ func TestIPv4Sanity(t *testing.T) {
// Default routes for IPv4 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv4EmptySubnet,
NIC: nicID,
},
@@ -369,11 +369,10 @@ func TestIPv4Sanity(t *testing.T) {
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
- t.Helper()
- // Make a complete array of the sourcePacketInfo packet.
- source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize])
- vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views())
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32) error {
+ // Make a complete array of the sourcePacket packet.
+ source := header.IPv4(packets[0].NetworkHeader().View())
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
source = append(source, vv.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
@@ -384,46 +383,49 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
sourceCopy.SetTotalLength(0)
var offset uint16
// Build up an array of the bytes sent.
- var reassembledPayload []byte
+ var reassembledPayload buffer.VectorisedView
for i, packet := range packets {
// Confirm that the packet is valid.
allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
- ip := header.IPv4(allBytes.ToView())
- if !ip.IsValid(len(ip)) {
- t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ fragmentIPHeader := header.IPv4(allBytes.ToView())
+ if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader))
}
- if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
- t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader.CalculateChecksum() got %#x, want %#x", i, got, want)
}
- if got, want := len(ip), int(mtu); got > want {
- t.Errorf("fragment is too large, got %d want %d", got, want)
+ if got := len(fragmentIPHeader); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu)
}
- if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
- t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ if got, want := packet.AvailableHeaderBytes(), sourcePacket.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
+ return fmt.Errorf("fragment #%d: should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
- if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want {
- t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want)
+ if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
+ return fmt.Errorf("fragment #%d: has wrong network protocol number: got %d, want %d", i, got, want)
}
if i < len(packets)-1 {
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
} else {
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
}
- reassembledPayload = append(reassembledPayload, ip.Payload()...)
- offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ reassembledPayload.AppendView(packet.TransportHeader().View())
+ reassembledPayload.Append(packet.Data)
+ offset += fragmentIPHeader.TotalLength() - uint16(fragmentIPHeader.HeaderLength())
// Clear out the checksum and length from the ip because we can't compare
// it.
- sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetTotalLength(uint16(len(fragmentIPHeader)))
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
- if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
- t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader[:fragmentIPHeader.HeaderLength()] mismatch (-want +got):\n%s", i, diff)
}
}
- expected := source[source.HeaderLength():]
- if !bytes.Equal(reassembledPayload, expected) {
- t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+ expected := buffer.View(source[source.HeaderLength():])
+ if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
}
+
+ return nil
}
func TestFragmentation(t *testing.T) {
@@ -477,7 +479,9 @@ func TestFragmentation(t *testing.T) {
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- compareFragments(t, ep.WrittenPackets, source, ft.mtu)
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu); err != nil {
+ t.Error(err)
+ }
})
}
}
@@ -1633,7 +1637,7 @@ func TestPacketQueing(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
NIC: nicID,
},