summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv6/async/client_test.go69
1 files changed, 50 insertions, 19 deletions
diff --git a/dhcpv6/async/client_test.go b/dhcpv6/async/client_test.go
index 8665589..0bc3a87 100644
--- a/dhcpv6/async/client_test.go
+++ b/dhcpv6/async/client_test.go
@@ -1,15 +1,17 @@
package async
import (
+ "context"
"net"
"testing"
"time"
- "github.com/stretchr/testify/require"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/iana"
+ "github.com/stretchr/testify/require"
)
+const retries = 5
// solicit creates new solicit based on the mac address
func solicit(input string) (dhcpv6.DHCPv6, error) {
@@ -26,21 +28,39 @@ func solicit(input string) (dhcpv6.DHCPv6, error) {
return dhcpv6.NewSolicitWithCID(duid)
}
-// server creates a server which responds with predefined answers
-func serve(t *testing.T, addr *net.UDPAddr, responses ...dhcpv6.DHCPv6) {
+// server creates a server which responds with a predefined response
+func serve(ctx context.Context, addr *net.UDPAddr, response dhcpv6.DHCPv6) error {
conn, err := net.ListenUDP("udp6", addr)
- require.NoError(t, err)
- defer conn.Close()
- oobdata := []byte{}
- buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize)
- for _, packet := range responses {
- n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata)
- require.NoError(t, err)
- _, err = dhcpv6.FromBytes(buffer[:n])
- require.NoError(t, err)
- _, err = conn.WriteTo(packet.ToBytes(), src)
- require.NoError(t, err)
+ if err != nil {
+ return err
}
+ go func() {
+ defer conn.Close()
+ oobdata := []byte{}
+ buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize)
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ conn.SetReadDeadline(time.Now().Add(1 * time.Second))
+ n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata)
+ if err != nil {
+ continue
+ }
+ _, err = dhcpv6.FromBytes(buffer[:n])
+ if err != nil {
+ continue
+ }
+ conn.SetWriteDeadline(time.Now().Add(1 * time.Second))
+ _, err = conn.WriteTo(response.ToBytes(), src)
+ if err != nil {
+ continue
+ }
+ }
+ }
+ }()
+ return nil
}
func TestNewClient(t *testing.T) {
@@ -108,15 +128,26 @@ func TestSend(t *testing.T) {
c.LocalAddr = addr
c.RemoteAddr = remote
- go serve(t, remote, a)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ err = serve(ctx, remote, a)
+ require.NoError(t, err)
err = c.Open(16)
require.NoError(t, err)
defer c.Close()
f := c.Send(s)
- response, err, timeout := f.GetOrTimeout(1000)
- require.False(t, timeout)
- require.NoError(t, err)
- require.Equal(t, a, response)
+
+ var passed bool
+ for i := 0; i < retries; i++ {
+ response, err, timeout := f.GetOrTimeout(1000)
+ if timeout {
+ continue
+ }
+ passed = true
+ require.NoError(t, err)
+ require.Equal(t, a, response)
+ }
+ require.True(t, passed, "All attempts to TestSend timed out")
}