summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4/ipv4_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/ipv4/ipv4_test.go')
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go41
1 files changed, 18 insertions, 23 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 01dfb5f20..e900f1b45 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -113,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) {
+func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
source = append(source, sourcePacketInfo.Header.View()...)
- source = append(source, sourcePacketInfo.Payload.ToView()...)
+ source = append(source, sourcePacketInfo.Data.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -132,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
for i, packet := range packets {
// Confirm that the packet is valid.
allBytes := packet.Header.View().ToVectorisedView()
- allBytes.Append(packet.Payload)
+ allBytes.Append(packet.Data)
ip := header.IPv4(allBytes.ToView())
if !ip.IsValid(len(ip)) {
t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
@@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
type errorChannel struct {
*channel.Endpoint
- Ch chan packetInfo
+ Ch chan tcpip.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
@@ -183,17 +183,11 @@ type errorChannel struct {
func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
return &errorChannel{
Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan packetInfo, size),
+ Ch: make(chan tcpip.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
}
-// packetInfo holds all the information about an outbound packet.
-type packetInfo struct {
- Header buffer.Prependable
- Payload buffer.VectorisedView
-}
-
// Drain removes all outbound packets from the channel and counts them.
func (e *errorChannel) Drain() int {
c := 0
@@ -208,14 +202,9 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- p := packetInfo{
- Header: hdr,
- Payload: payload,
- }
-
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
select {
- case e.Ch <- p:
+ case e.Ch <- pkt:
default:
}
@@ -292,18 +281,21 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := packetInfo{
+ source := tcpip.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
- Payload: payload.Clone([]buffer.View{}),
+ Data: payload.Clone(nil),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
if err != nil {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []packetInfo
+ var results []tcpip.PacketBuffer
L:
for {
select {
@@ -345,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)