diff options
Diffstat (limited to 'pkg/tcpip/link/sharedmem/sharedmem_test.go')
-rw-r--r-- | pkg/tcpip/link/sharedmem/sharedmem_test.go | 158 |
1 files changed, 109 insertions, 49 deletions
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 9a0348deb..9a6b7d929 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -17,10 +17,12 @@ package sharedmem import ( + "bytes" "io/ioutil" "math/rand" "os" "reflect" + "strings" "sync" "syscall" "testing" @@ -129,10 +131,10 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ - addr: remoteAddr, + addr: remoteLinkAddr, proto: proto, vv: vv.Clone(nil), }) @@ -259,62 +261,120 @@ func TestSimpleSend(t *testing.T) { } for iters := 1000; iters > 0; iters-- { - // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) - randomFill(hdrBuf) - - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) - - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), proto); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } + func() { + // Prepare and send packet. + n := rand.Intn(10000) + hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) + hdrBuf := hdr.Prepend(n) + randomFill(hdrBuf) + + n = rand.Intn(10000) + buf := buffer.NewView(n) + randomFill(buf) + + proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), proto); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } - // Receive packet. - desc := c.txq.tx.Pull() - pi := queue.DecodeTxPacketHeader(desc) - contents := make([]byte, 0, pi.Size) - for i := 0; i < pi.BufferCount; i++ { - bi := queue.DecodeTxBufferHeader(desc, i) - contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) - } - c.txq.tx.Flush() + // Receive packet. + desc := c.txq.tx.Pull() + pi := queue.DecodeTxPacketHeader(desc) + if pi.Reserved != 0 { + t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) + } + contents := make([]byte, 0, pi.Size) + for i := 0; i < pi.BufferCount; i++ { + bi := queue.DecodeTxBufferHeader(desc, i) + contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) + } + c.txq.tx.Flush() + + defer func() { + // Tell the endpoint about the completion of the write. + b := c.txq.rx.Push(8) + queue.EncodeTxCompletion(b, pi.ID) + c.txq.rx.Flush() + }() + + // Check the ethernet header. + ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) + ethTemplate.Encode(&header.EthernetFields{ + SrcAddr: localLinkAddr, + DstAddr: remoteLinkAddr, + Type: proto, + }) + if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { + t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) + } - if pi.Reserved != 0 { - t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) - } + // Compare contents skipping the ethernet header added by the + // endpoint. + merged := append(hdrBuf, buf...) + if uint32(len(contents)) < pi.Size { + t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) + } + contents = contents[:pi.Size][header.EthernetMinimumSize:] - // Check the thernet header. - ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) - ethTemplate.Encode(&header.EthernetFields{ - SrcAddr: localLinkAddr, - DstAddr: remoteLinkAddr, - Type: proto, - }) - if got := contents[:header.EthernetMinimumSize]; !reflect.DeepEqual(got, []byte(ethTemplate)) { - t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) - } + if !bytes.Equal(contents, merged) { + t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) + } + }() + } +} - // Compare contents skipping the ethernet header added by the - // endpoint. - merged := append(hdrBuf, buf...) - if uint32(len(contents)) < pi.Size { - t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) - } - contents = contents[:pi.Size][header.EthernetMinimumSize:] +// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress +// set in Route (using much of the same code as TestSimpleSend), then checks +// that the encoded ethernet header received includes the correct SrcAddr. +func TestPreserveSrcAddressInSend(t *testing.T) { + c := newTestContext(t, 20000, 1500, localLinkAddr) + defer c.cleanup() - if !reflect.DeepEqual(contents, merged) { - t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) - } + newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) + // Set both remote and local link address in route. + r := stack.Route{ + RemoteLinkAddress: remoteLinkAddr, + LocalLinkAddress: newLocalLinkAddress, + } + + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + hdr := buffer.NewPrependable(header.EthernetMinimumSize) + + proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + if err := c.ep.WritePacket(&r, hdr, buffer.VectorisedView{}, proto); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } + // Receive packet. + desc := c.txq.tx.Pull() + pi := queue.DecodeTxPacketHeader(desc) + if pi.Reserved != 0 { + t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) + } + contents := make([]byte, 0, pi.Size) + for i := 0; i < pi.BufferCount; i++ { + bi := queue.DecodeTxBufferHeader(desc, i) + contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) + } + c.txq.tx.Flush() + + defer func() { // Tell the endpoint about the completion of the write. b := c.txq.rx.Push(8) queue.EncodeTxCompletion(b, pi.ID) c.txq.rx.Flush() + }() + + // Check that the ethernet header contains the expected SrcAddr. + ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) + ethTemplate.Encode(&header.EthernetFields{ + SrcAddr: newLocalLinkAddress, + DstAddr: remoteLinkAddr, + Type: proto, + }) + if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) { + t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) } } @@ -583,7 +643,7 @@ func TestSimpleReceive(t *testing.T) { c.mu.Unlock() contents = contents[header.EthernetMinimumSize:] - if !reflect.DeepEqual(contents, rcvd) { + if !bytes.Equal(contents, rcvd) { t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents) } |