diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/packetimpact/tests/udp_icmp_error_propagation_test.go | 144 |
1 files changed, 123 insertions, 21 deletions
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index 66f64eaa2..c47af9a3e 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net" + "sync" "syscall" "testing" "time" @@ -79,6 +80,38 @@ type testData struct { wantErrno syscall.Errno } +// wantErrno computes the errno to expect given the connection mode of a UDP +// socket and the ICMP error it will receive. +func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno { + if c && icmpErr == portUnreachable { + return syscall.Errno(unix.ECONNREFUSED) + } + return syscall.Errno(0) +} + +// sendICMPError sends an ICMP error message in response to a UDP datagram. +func sendICMPError(conn *tb.UDPIPv4, icmpErr icmpError, udp *tb.UDP) error { + if icmpErr == timeToLiveExceeded { + ip, ok := udp.Prev().(*tb.IPv4) + if !ok { + return fmt.Errorf("expected %s to be IPv4", udp.Prev()) + } + *ip.TTL = 1 + // Let serialization recalculate the checksum since we set the TTL + // to 1. + ip.Checksum = nil + + // Note that the ICMP payload is valid in this case because the UDP + // payload is empty. If the UDP payload were not empty, the packet + // length during serialization may not be calculated correctly, + // resulting in a mal-formed packet. + conn.SendIP(icmpErr.ToICMPv4(), ip, udp) + } else { + conn.SendIP(icmpErr.ToICMPv4(), udp.Prev(), udp) + } + return nil +} + // testRecv tests observing the ICMP error through the recv syscall. A packet // is sent to the DUT, and if wantErrno is non-zero, then the first recv should // fail and the second should succeed. Otherwise if wantErrno is zero then the @@ -174,10 +207,8 @@ func testSockOpt(_ context.Context, d testData) error { func TestUDPICMPErrorPropagation(t *testing.T) { for _, connect := range []connectionMode{true, false} { for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} { - wantErrno := syscall.Errno(0) - if connect && icmpErr == portUnreachable { - wantErrno = unix.ECONNREFUSED - } + wantErrno := wantErrno(connect, icmpErr) + for _, errDetect := range []errorDetection{ errorDetection{"SendTo", false, testSendTo}, // Send to an address that's different from the one that caused an ICMP @@ -212,23 +243,8 @@ func TestUDPICMPErrorPropagation(t *testing.T) { t.Fatalf("did not receive message from DUT: %s", err) } - if icmpErr == timeToLiveExceeded { - ip, ok := udp.Prev().(*tb.IPv4) - if !ok { - t.Fatalf("expected %s to be IPv4", udp.Prev()) - } - *ip.TTL = 1 - // Let serialization recalculate the checksum since we set the TTL - // to 1. - ip.Checksum = nil - - // Note that the ICMP payload is valid in this case because the UDP - // payload is empty. If the UDP payload were not empty, the packet - // length during serialization may not be calculated correctly, - // resulting in a mal-formed packet. - conn.SendIP(icmpErr.ToICMPv4(), ip, udp) - } else { - conn.SendIP(icmpErr.ToICMPv4(), udp.Prev(), udp) + if err := sendICMPError(&conn, icmpErr, udp); err != nil { + t.Fatal(err) } errDetectConn := &conn @@ -251,3 +267,89 @@ func TestUDPICMPErrorPropagation(t *testing.T) { } } } + +// TestICMPErrorDuringUDPRecv tests behavior when a UDP socket is in the middle +// of a blocking recv and receives an ICMP error. +func TestICMPErrorDuringUDPRecv(t *testing.T) { + for _, connect := range []connectionMode{true, false} { + for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} { + wantErrno := wantErrno(connect, icmpErr) + + t.Run(fmt.Sprintf("%s/%s", connect, icmpErr), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + + remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(remoteFD) + + // Create a second, clean socket on the DUT to ensure that the ICMP + // error messages only affect the sockets they are intended for. + cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(cleanFD) + + conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort}) + defer conn.Close() + + if connect { + dut.Connect(remoteFD, conn.LocalAddr()) + dut.Connect(cleanFD, conn.LocalAddr()) + } + + dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) + udp, err := conn.Expect(tb.UDP{}, time.Second) + if err != nil { + t.Fatalf("did not receive message from DUT: %s", err) + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + if wantErrno != syscall.Errno(0) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0) + if ret != -1 { + t.Fatalf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) + } + if err != wantErrno { + t.Fatalf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno) + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 { + t.Fatalf("recv after ICMP error failed with (%[1]d) %[1]", err) + } + wg.Done() + }() + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 { + t.Fatalf("recv on clean socket failed with (%[1]d) %[1]", err) + } + wg.Done() + }() + + // TODO(b/155684889) This sleep is to allow time for the DUT to + // actually call recv since we want the ICMP error to arrive during the + // blocking recv, and should be replaced when a better synchronization + // alternative is available. + time.Sleep(2 * time.Second) + + if err := sendICMPError(&conn, icmpErr, udp); err != nil { + t.Fatal(err) + } + + conn.Send(tb.UDP{DstPort: &cleanPort}) + conn.Send(tb.UDP{}) + wg.Wait() + }) + } + } +} |