diff options
Diffstat (limited to 'dhcpv6')
-rw-r--r-- | dhcpv6/async/client_test.go | 69 |
1 files changed, 50 insertions, 19 deletions
diff --git a/dhcpv6/async/client_test.go b/dhcpv6/async/client_test.go index 8665589..0bc3a87 100644 --- a/dhcpv6/async/client_test.go +++ b/dhcpv6/async/client_test.go @@ -1,15 +1,17 @@ package async import ( + "context" "net" "testing" "time" - "github.com/stretchr/testify/require" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/iana" + "github.com/stretchr/testify/require" ) +const retries = 5 // solicit creates new solicit based on the mac address func solicit(input string) (dhcpv6.DHCPv6, error) { @@ -26,21 +28,39 @@ func solicit(input string) (dhcpv6.DHCPv6, error) { return dhcpv6.NewSolicitWithCID(duid) } -// server creates a server which responds with predefined answers -func serve(t *testing.T, addr *net.UDPAddr, responses ...dhcpv6.DHCPv6) { +// server creates a server which responds with a predefined response +func serve(ctx context.Context, addr *net.UDPAddr, response dhcpv6.DHCPv6) error { conn, err := net.ListenUDP("udp6", addr) - require.NoError(t, err) - defer conn.Close() - oobdata := []byte{} - buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize) - for _, packet := range responses { - n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) - require.NoError(t, err) - _, err = dhcpv6.FromBytes(buffer[:n]) - require.NoError(t, err) - _, err = conn.WriteTo(packet.ToBytes(), src) - require.NoError(t, err) + if err != nil { + return err } + go func() { + defer conn.Close() + oobdata := []byte{} + buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize) + for { + select { + case <-ctx.Done(): + return + default: + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) + if err != nil { + continue + } + _, err = dhcpv6.FromBytes(buffer[:n]) + if err != nil { + continue + } + conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err = conn.WriteTo(response.ToBytes(), src) + if err != nil { + continue + } + } + } + }() + return nil } func TestNewClient(t *testing.T) { @@ -108,15 +128,26 @@ func TestSend(t *testing.T) { c.LocalAddr = addr c.RemoteAddr = remote - go serve(t, remote, a) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err = serve(ctx, remote, a) + require.NoError(t, err) err = c.Open(16) require.NoError(t, err) defer c.Close() f := c.Send(s) - response, err, timeout := f.GetOrTimeout(1000) - require.False(t, timeout) - require.NoError(t, err) - require.Equal(t, a, response) + + var passed bool + for i := 0; i < retries; i++ { + response, err, timeout := f.GetOrTimeout(1000) + if timeout { + continue + } + passed = true + require.NoError(t, err) + require.Equal(t, a, response) + } + require.True(t, passed, "All attempts to TestSend timed out") } |