summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp/udp_test.go
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2019-10-09 17:54:51 -0700
committergVisor bot <gvisor-bot@google.com>2019-10-09 17:56:05 -0700
commitbf870c1a423063eb86a62c6268fe5d83fb6b87ba (patch)
treef08f7db5122ad778647fcc7f564f7e5cab657376 /pkg/tcpip/transport/udp/udp_test.go
parent7a2d5b2fa7c398f7710a134b5790265bf620fced (diff)
Internal change.
PiperOrigin-RevId: 273861936
Diffstat (limited to 'pkg/tcpip/transport/udp/udp_test.go')
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go175
1 files changed, 158 insertions, 17 deletions
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index faa728b68..4ada73475 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -384,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h)
+ c.injectV4Packet(payload, &h, true /* valid */)
} else {
- c.injectV6Packet(payload, &h)
+ c.injectV6Packet(payload, &h, true /* valid */)
}
}
// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -409,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
+ l := uint16(header.UDPMinimumSize + len(payload))
+ if !valid {
+ // Change the UDP payload length to corrupt the header
+ // as requested by the caller.
+ l++
+ }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
+ Length: l,
})
// Calculate the UDP pseudo-header checksum.
@@ -426,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-// injectV6Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
+// injectV4Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint. valid indicates if the
+// caller intends to inject a packet with a valid or an invalid UDP header.
+// We can invalidate the header by corrupting the UDP payload length.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -536,7 +546,7 @@ func TestBindToDeviceOption(t *testing.T) {
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) {
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
c.t.Helper()
payload := newPayload()
@@ -547,6 +557,9 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
c.wq.EventRegister(&we, waiter.EventIn)
defer c.wq.EventUnregister(&we)
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
var addr tcpip.FullAddress
v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
@@ -563,6 +576,11 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
}
}
+ if expectReadError && err != nil {
+ c.checkEndpointReadStats(1, epstats, err)
+ return
+ }
+
if err != nil {
c.t.Fatal("Read failed:", err)
}
@@ -581,6 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+ c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
@@ -588,15 +607,15 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool)
// its correctness.
func testRead(c *testContext, flow testFlow) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
}
// testFailingRead sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then tries to read it from the UDP
// endpoint and expects this to fail.
-func testFailingRead(c *testContext, flow testFlow) {
+func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
c.t.Helper()
- testReadInternal(c, flow, true /* packetShouldBeDropped */)
+ testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
}
func TestBindEphemeralPort(t *testing.T) {
@@ -771,8 +790,8 @@ func TestReadOnBoundToMulticast(t *testing.T) {
// Check that we receive multicast packets but not unicast or broadcast
// ones.
testRead(c, flow)
- testFailingRead(c, broadcast)
- testFailingRead(c, unicastV4)
+ testFailingRead(c, broadcast, false /* expectReadError */)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
})
}
}
@@ -795,7 +814,7 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
// Check that we receive broadcast packets but not unicast ones.
testRead(c, flow)
- testFailingRead(c, unicastV4)
+ testFailingRead(c, unicastV4, false /* expectReadError */)
})
}
}
@@ -826,7 +845,8 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
// and verifies it fails with the provided error code.
func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
c.t.Helper()
-
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
h := flow.header4Tuple(outgoing)
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
@@ -834,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
_, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
+ c.checkEndpointWriteStats(1, epstats, gotErr)
if gotErr != wantErr {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
@@ -859,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec
func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
+ // Take a snapshot of the stats to validate them at the end of the test.
+ epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
writeOpts := tcpip.WriteOptions{}
if setDest {
@@ -876,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
-
+ c.checkEndpointWriteStats(1, epstats, err)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -945,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Write to V4 mapped address.
testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+ const want = 1
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
+ c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
+ }
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
@@ -1453,3 +1480,117 @@ func TestV6UnknownDestination(t *testing.T) {
})
}
}
+
+// TestIncrementMalformedPacketsReceived verifies if the malformed received
+// global and endpoint stats get incremented.
+func TestIncrementMalformedPacketsReceived(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ payload := newPayload()
+ c.t.Helper()
+ h := unicastV6.header4Tuple(incoming)
+ c.injectV6Packet(payload, &h, false /* !valid */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownRead verifies endpoint read shutdown and error
+// stats increment on packet receive.
+func TestShutdownRead(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ testFailingRead(c, unicastV6, true /* expectReadError */)
+
+ var want uint64 = 1
+ if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
+ }
+}
+
+// TestShutdownWrite verifies endpoint write shutdown and error
+// stats increment on packet write.
+func TestShutdownWrite(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+
+ if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+}
+
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil:
+ want.PacketsSent.IncrementBy(incr)
+ case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ want.WriteErrors.InvalidArgs.IncrementBy(incr)
+ case tcpip.ErrClosedForSend:
+ want.WriteErrors.WriteClosed.IncrementBy(incr)
+ case tcpip.ErrInvalidEndpointState:
+ want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
+ case tcpip.ErrNoLinkAddress:
+ want.SendErrors.NoLinkAddr.IncrementBy(incr)
+ case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ want.SendErrors.NoRoute.IncrementBy(incr)
+ default:
+ want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}
+
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+ got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+ switch err {
+ case nil, tcpip.ErrWouldBlock:
+ case tcpip.ErrClosedForReceive:
+ want.ReadErrors.ReadClosed.IncrementBy(incr)
+ default:
+ c.t.Errorf("Endpoint error missing stats update err %v", err)
+ }
+ if got != want {
+ c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
+ }
+}