summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/link/channel/channel.go43
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go7
-rw-r--r--pkg/tcpip/stack/ndp_test.go87
-rw-r--r--pkg/tcpip/stack/stack_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go86
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go182
8 files changed, 229 insertions, 202 deletions
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 70188551f..71b9da797 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -18,6 +18,8 @@
package channel
import (
+ "context"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -38,25 +40,52 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
GSO bool
- // C is where outbound packets are queued.
- C chan PacketInfo
+ // c is where outbound packets are queued.
+ c chan PacketInfo
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- C: make(chan PacketInfo, size),
+ c: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
}
+// Close closes e. Further packet injections will panic. Reads continue to
+// succeed until all packets are read.
+func (e *Endpoint) Close() {
+ close(e.c)
+}
+
+// Read does non-blocking read for one packet from the outbound packet queue.
+func (e *Endpoint) Read() (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+// ReadContext does blocking read for one packet from the outbound packet queue.
+// It can be cancelled by ctx, and in this case, it returns false.
+func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
select {
- case <-e.C:
+ case <-e.c:
c++
default:
return c
@@ -125,7 +154,7 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
@@ -150,7 +179,7 @@ packetLoop:
}
select {
- case e.C <- p:
+ case e.c <- p:
n++
default:
break packetLoop
@@ -169,7 +198,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 8e6048a21..03cf03b6d 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,6 +15,7 @@
package arp_test
import (
+ "context"
"strconv"
"testing"
"time"
@@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP.C)
+ c.linkEP.Close()
}
func TestDirectRequest(t *testing.T) {
@@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) {
for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
- pi := <-c.linkEP.C
+ pi, _ := c.linkEP.ReadContext(context.Background())
if pi.Proto != arp.ProtocolNumber {
t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
@@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) {
}
inject(stackAddrBad)
- select {
- case pkt := <-c.linkEP.C:
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
+ ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
- case <-time.After(100 * time.Millisecond):
- // Sleep tests are gross, but this will only potentially flake
- // if there's a bug. If there is no bug this will reliably
- // succeed.
}
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index a2fdc5dcd..7a6820643 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "context"
"reflect"
"strings"
"testing"
@@ -264,8 +265,8 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP0.C)
- close(c.linkEP1.C)
+ c.linkEP0.Close()
+ c.linkEP1.Close()
}
type routeArgs struct {
@@ -276,7 +277,7 @@ type routeArgs struct {
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pi := <-args.src.C
+ pi, _ := args.src.ReadContext(context.Background())
{
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index f9460bd51..ad2c6f601 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "context"
"encoding/binary"
"fmt"
"testing"
@@ -405,7 +406,7 @@ func TestDADResolve(t *testing.T) {
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p := <-e.C
+ p, _ := e.ReadContext(context.Background())
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
@@ -3285,29 +3286,29 @@ func TestRouterSolicitation(t *testing.T) {
e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t,
- p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(),
- )
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
}
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(),
+ )
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet")
- case <-time.After(timeout):
}
}
s := stack.New(stack.Options{
@@ -3362,20 +3363,21 @@ func TestStopStartSolicitingRouters(t *testing.T) {
e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
}
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
@@ -3391,23 +3393,20 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Enable forwarding which should stop router solicitations.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before forwarding was enabled.
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok = e.ReadContext(ctx); ok {
t.Fatal("Should not have sent more than one RS message")
- case <-time.After(interval + defaultTimeout):
}
- case <-time.After(delay + defaultTimeout):
}
// Enabling forwarding again should do nothing.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
// Disable forwarding which should start router solicitations.
@@ -3415,17 +3414,15 @@ func TestStopStartSolicitingRouters(t *testing.T) {
waitForPkt(delay + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
waitForPkt(interval + defaultAsyncEventTimeout)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
- case <-time.After(interval + defaultTimeout):
}
// Disabling forwarding again should do nothing.
s.SetForwarding(false)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dad288642..834fe9487 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) {
Data: buf.ToVectorisedView(),
})
- select {
- case <-ep2.C:
- default:
+ if _, ok := ep2.Read(); !ok {
t.Fatal("Packet not forwarded")
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index f50604a8a..869c69a6d 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) {
t.Fatalf("Write failed: %v", err)
}
- var p channel.PacketInfo
- select {
- case p = <-ep2.C:
- default:
+ p, ok := ep2.Read()
+ if !ok {
t.Fatal("Response packet not forwarded")
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 822907998..730ac4292 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -18,6 +18,7 @@ package context
import (
"bytes"
+ "context"
"testing"
"time"
@@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack {
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
- select {
- case <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), wait)
+ if _, ok := c.linkEP.ReadContext(ctx); ok {
c.t.Fatal(errMsg)
-
- case <-time.After(wait):
}
}
@@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) {
// 2 seconds.
func (c *Context) GetPacket() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
- if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
- }
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
+ if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
+ c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
}
- return nil
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// GetPacketNonBlocking reads a packet from the link layer endpoint
@@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte {
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
- default:
+ p, ok := c.linkEP.Read()
+ if !ok {
return nil
}
+
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
@@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
-
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != ipv6.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
- return nil
+ checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+ return b
}
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c6927cfe3..f0ff3fe71 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "context"
"fmt"
"math/rand"
"testing"
@@ -357,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- h := flow.header4Tuple(outgoing)
- checkers := append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
}
- return nil
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ h := flow.header4Tuple(outgoing)
+ checkers = append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
+ return b
}
// injectPacket creates a packet of the given flow and with the given payload,
@@ -1541,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
- }
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
}
})
}
@@ -1615,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
+
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
}
})
}