diff options
-rw-r--r-- | dhcpv4/options.go | 3 | ||||
-rw-r--r-- | dhcpv4/options_test.go | 5 | ||||
-rw-r--r-- | dhcpv6/async/client_test.go | 69 | ||||
-rw-r--r-- | netboot/netboot.go | 4 |
4 files changed, 57 insertions, 24 deletions
diff --git a/dhcpv4/options.go b/dhcpv4/options.go index d869b7d..02fa6e4 100644 --- a/dhcpv4/options.go +++ b/dhcpv4/options.go @@ -126,6 +126,9 @@ func OptionsFromBytesWithoutMagicCookie(data []byte) ([]Option, error) { return nil, err } options = append(options, opt) + if opt.Code() == OptionEnd { + break + } // Options with zero length have no length byte, so here we handle the // ones with nonzero length diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go index 0268483..899fb2c 100644 --- a/dhcpv4/options_test.go +++ b/dhcpv4/options_test.go @@ -167,12 +167,9 @@ func TestOptionsFromBytes(t *testing.T) { } opts, err := OptionsFromBytes(options) require.NoError(t, err) - require.Equal(t, 5, len(opts)) + require.Equal(t, 2, len(opts)) require.Equal(t, opts[0].(*OptionGeneric), &OptionGeneric{OptionCode: OptionNameServer, Data: []byte{192, 168, 1, 1}}) require.Equal(t, opts[1].(*OptionGeneric), &OptionGeneric{OptionCode: OptionEnd}) - require.Equal(t, opts[2].(*OptionGeneric), &OptionGeneric{OptionCode: OptionPad}) - require.Equal(t, opts[3].(*OptionGeneric), &OptionGeneric{OptionCode: OptionPad}) - require.Equal(t, opts[4].(*OptionGeneric), &OptionGeneric{OptionCode: OptionPad}) } func TestOptionsFromBytesZeroLength(t *testing.T) { diff --git a/dhcpv6/async/client_test.go b/dhcpv6/async/client_test.go index 8665589..0bc3a87 100644 --- a/dhcpv6/async/client_test.go +++ b/dhcpv6/async/client_test.go @@ -1,15 +1,17 @@ package async import ( + "context" "net" "testing" "time" - "github.com/stretchr/testify/require" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/iana" + "github.com/stretchr/testify/require" ) +const retries = 5 // solicit creates new solicit based on the mac address func solicit(input string) (dhcpv6.DHCPv6, error) { @@ -26,21 +28,39 @@ func solicit(input string) (dhcpv6.DHCPv6, error) { return dhcpv6.NewSolicitWithCID(duid) } -// server creates a server which responds with predefined answers -func serve(t *testing.T, addr *net.UDPAddr, responses ...dhcpv6.DHCPv6) { +// server creates a server which responds with a predefined response +func serve(ctx context.Context, addr *net.UDPAddr, response dhcpv6.DHCPv6) error { conn, err := net.ListenUDP("udp6", addr) - require.NoError(t, err) - defer conn.Close() - oobdata := []byte{} - buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize) - for _, packet := range responses { - n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) - require.NoError(t, err) - _, err = dhcpv6.FromBytes(buffer[:n]) - require.NoError(t, err) - _, err = conn.WriteTo(packet.ToBytes(), src) - require.NoError(t, err) + if err != nil { + return err } + go func() { + defer conn.Close() + oobdata := []byte{} + buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize) + for { + select { + case <-ctx.Done(): + return + default: + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) + if err != nil { + continue + } + _, err = dhcpv6.FromBytes(buffer[:n]) + if err != nil { + continue + } + conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err = conn.WriteTo(response.ToBytes(), src) + if err != nil { + continue + } + } + } + }() + return nil } func TestNewClient(t *testing.T) { @@ -108,15 +128,26 @@ func TestSend(t *testing.T) { c.LocalAddr = addr c.RemoteAddr = remote - go serve(t, remote, a) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err = serve(ctx, remote, a) + require.NoError(t, err) err = c.Open(16) require.NoError(t, err) defer c.Close() f := c.Send(s) - response, err, timeout := f.GetOrTimeout(1000) - require.False(t, timeout) - require.NoError(t, err) - require.Equal(t, a, response) + + var passed bool + for i := 0; i < retries; i++ { + response, err, timeout := f.GetOrTimeout(1000) + if timeout { + continue + } + passed = true + require.NoError(t, err) + require.Equal(t, a, response) + } + require.True(t, passed, "All attempts to TestSend timed out") } diff --git a/netboot/netboot.go b/netboot/netboot.go index 81ba144..2f4cc8f 100644 --- a/netboot/netboot.go +++ b/netboot/netboot.go @@ -15,7 +15,6 @@ func RequestNetbootv6(ifname string, timeout time.Duration, retries int, modifie var ( conversation []dhcpv6.DHCPv6 ) - modifiers = append(modifiers, dhcpv6.WithNetboot) delay := 2 * time.Second for i := 0; i <= retries; i++ { log.Printf("sending request, attempt #%d", i+1) @@ -26,6 +25,9 @@ func RequestNetbootv6(ifname string, timeout time.Duration, retries int, modifie client := dhcpv6.NewClient() client.ReadTimeout = timeout + // WithNetboot is added only later, to avoid applying it twice (one + // here and one in the above call to NewSolicitForInterface) + modifiers = append(modifiers, dhcpv6.WithNetboot) conversation, err = client.Exchange(ifname, solicit, modifiers...) if err != nil { log.Printf("Client.Exchange failed: %v", err) |