diff options
Diffstat (limited to 'test/iptables/iptables_util.go')
-rw-r--r-- | test/iptables/iptables_util.go | 85 |
1 files changed, 39 insertions, 46 deletions
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 5125fe47b..a6ec5cca3 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -15,6 +15,7 @@ package iptables import ( + "context" "encoding/binary" "errors" "fmt" @@ -70,7 +71,7 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error { // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. -func listenUDP(port int, timeout time.Duration) error { +func listenUDP(ctx context.Context, port int) error { localAddr := net.UDPAddr{ Port: port, } @@ -79,68 +80,53 @@ func listenUDP(port int, timeout time.Duration) error { return err } defer conn.Close() - conn.SetDeadline(time.Now().Add(timeout)) - _, err = conn.Read([]byte{0}) - return err -} -// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified -// over a duration. -func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { - conn, err := connectUDP(ip, port) - if err != nil { - return err - } - defer conn.Close() - loopUDP(conn, duration) - return nil -} + ch := make(chan error) + go func() { + _, err = conn.Read([]byte{0}) + ch <- err + }() -// spawnUDPLoop works like sendUDPLoop, but returns immediately and sends -// packets in another goroutine. -func spawnUDPLoop(ip net.IP, port int, duration time.Duration) error { - conn, err := connectUDP(ip, port) - if err != nil { + select { + case err := <-ch: return err + case <-ctx.Done(): + return ctx.Err() } - go func() { - defer conn.Close() - loopUDP(conn, duration) - }() - return nil } -func connectUDP(ip net.IP, port int) (net.Conn, error) { +// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified +// over a duration. +func sendUDPLoop(ctx context.Context, ip net.IP, port int) error { remote := net.UDPAddr{ IP: ip, Port: port, } conn, err := net.DialUDP("udp", nil, &remote) if err != nil { - return nil, err + return err } - return conn, nil -} + defer conn.Close() -func loopUDP(conn net.Conn, duration time.Duration) { - to := time.After(duration) - for timedOut := false; !timedOut; { + for { // This may return an error (connection refused) if the remote // hasn't started listening yet or they're dropping our // packets. So we ignore Write errors and depend on the remote // to report a failure if it doesn't get a packet it needs. conn.Write([]byte{0}) select { - case <-to: - timedOut = true - default: - time.Sleep(200 * time.Millisecond) + case <-ctx.Done(): + // Being cancelled or timing out isn't an error, as we + // cannot tell with UDP whether we succeeded. + return nil + // Continue looping. + case <-time.After(200 * time.Millisecond): } } } // listenTCP listens for connections on a TCP port. -func listenTCP(port int, timeout time.Duration) error { +func listenTCP(ctx context.Context, port int) error { localAddr := net.TCPAddr{ Port: port, } @@ -153,17 +139,23 @@ func listenTCP(port int, timeout time.Duration) error { defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - conn, err := lConn.AcceptTCP() - if err != nil { + ch := make(chan error) + go func() { + conn, err := lConn.AcceptTCP() + ch <- err + conn.Close() + }() + + select { + case err := <-ch: return err + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err()) } - conn.Close() - return nil } // connectTCP connects to the given IP and port from an ephemeral local address. -func connectTCP(ip net.IP, port int, timeout time.Duration) error { +func connectTCP(ctx context.Context, ip net.IP, port int) error { contAddr := net.TCPAddr{ IP: ip, Port: port, @@ -171,13 +163,14 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error { // The container may not be listening when we first connect, so retry // upon error. callback := func() error { - conn, err := net.DialTimeout("tcp", contAddr.String(), timeout) + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", contAddr.String()) if conn != nil { conn.Close() } return err } - if err := testutil.Poll(callback, timeout); err != nil { + if err := testutil.PollContext(ctx, callback); err != nil { return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) } |