diff options
author | gVisor bot <gvisor-bot@google.com> | 2019-10-09 17:54:51 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-10-09 17:56:05 -0700 |
commit | bf870c1a423063eb86a62c6268fe5d83fb6b87ba (patch) | |
tree | f08f7db5122ad778647fcc7f564f7e5cab657376 /pkg/tcpip/transport/udp/udp_test.go | |
parent | 7a2d5b2fa7c398f7710a134b5790265bf620fced (diff) |
Internal change.
PiperOrigin-RevId: 273861936
Diffstat (limited to 'pkg/tcpip/transport/udp/udp_test.go')
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 175 |
1 files changed, 158 insertions, 17 deletions
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index faa728b68..4ada73475 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -384,15 +384,17 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h) + c.injectV4Packet(payload, &h, true /* valid */) } else { - c.injectV6Packet(payload, &h) + c.injectV6Packet(payload, &h, true /* valid */) } } // injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -409,10 +411,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) + l := uint16(header.UDPMinimumSize + len(payload)) + if !valid { + // Change the UDP payload length to corrupt the header + // as requested by the caller. + l++ + } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), + Length: l, }) // Calculate the UDP pseudo-header checksum. @@ -426,9 +434,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -// injectV6Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { +// injectV4Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -536,7 +546,7 @@ func TestBindToDeviceOption(t *testing.T) { // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its // correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) { +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -547,6 +557,9 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) c.wq.EventRegister(&we, waiter.EventIn) defer c.wq.EventUnregister(&we) + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + var addr tcpip.FullAddress v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { @@ -563,6 +576,11 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) } } + if expectReadError && err != nil { + c.checkEndpointReadStats(1, epstats, err) + return + } + if err != nil { c.t.Fatal("Read failed:", err) } @@ -581,6 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it @@ -588,15 +607,15 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) // its correctness. func testRead(c *testContext, flow testFlow) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) } // testFailingRead sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then tries to read it from the UDP // endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow) { +func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */) + testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) } func TestBindEphemeralPort(t *testing.T) { @@ -771,8 +790,8 @@ func TestReadOnBoundToMulticast(t *testing.T) { // Check that we receive multicast packets but not unicast or broadcast // ones. testRead(c, flow) - testFailingRead(c, broadcast) - testFailingRead(c, unicastV4) + testFailingRead(c, broadcast, false /* expectReadError */) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } @@ -795,7 +814,7 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { // Check that we receive broadcast packets but not unicast ones. testRead(c, flow) - testFailingRead(c, unicastV4) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } @@ -826,7 +845,8 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { // and verifies it fails with the provided error code. func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { c.t.Helper() - + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) @@ -834,6 +854,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) + c.checkEndpointWriteStats(1, epstats, gotErr) if gotErr != wantErr { c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } @@ -859,6 +880,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() writeOpts := tcpip.WriteOptions{} if setDest { @@ -876,7 +899,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - + c.checkEndpointWriteStats(1, epstats, err) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -945,6 +968,10 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Write to V4 mapped address. testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + const want = 1 + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { + c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) + } } func TestDualWriteConnectedToV4Mapped(t *testing.T) { @@ -1453,3 +1480,117 @@ func TestV6UnknownDestination(t *testing.T) { }) } } + +// TestIncrementMalformedPacketsReceived verifies if the malformed received +// global and endpoint stats get incremented. +func TestIncrementMalformedPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + payload := newPayload() + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + c.injectV6Packet(payload, &h, false /* !valid */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } +} + +// TestShutdownRead verifies endpoint read shutdown and error +// stats increment on packet receive. +func TestShutdownRead(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingRead(c, unicastV6, true /* expectReadError */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { + t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) + } +} + +// TestShutdownWrite verifies endpoint write shutdown and error +// stats increment on packet write. +func TestShutdownWrite(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) +} + +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil: + want.PacketsSent.IncrementBy(incr) + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + want.WriteErrors.InvalidArgs.IncrementBy(incr) + case tcpip.ErrClosedForSend: + want.WriteErrors.WriteClosed.IncrementBy(incr) + case tcpip.ErrInvalidEndpointState: + want.WriteErrors.InvalidEndpointState.IncrementBy(incr) + case tcpip.ErrNoLinkAddress: + want.SendErrors.NoLinkAddr.IncrementBy(incr) + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + want.SendErrors.NoRoute.IncrementBy(incr) + default: + want.SendErrors.SendToNetworkFailed.IncrementBy(incr) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} + +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil, tcpip.ErrWouldBlock: + case tcpip.ErrClosedForReceive: + want.ReadErrors.ReadClosed.IncrementBy(incr) + default: + c.t.Errorf("Endpoint error missing stats update err %v", err) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} |