diff options
-rw-r--r-- | dhcpv4/bsdp/bsdp.go | 29 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp_test.go | 2 | ||||
-rw-r--r-- | dhcpv4/bsdp/option_vendor_specific_information.go | 2 | ||||
-rw-r--r-- | dhcpv4/bsdp/option_vendor_specific_information_test.go | 8 |
4 files changed, 19 insertions, 22 deletions
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index bb5166e..51ec20c 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -18,17 +18,17 @@ const MaxDHCPMessageSize = 1500 // ACK[LIST] packet and returns them as a list of BootImages. func ParseBootImageListFromAck(ack dhcpv4.DHCPv4) ([]BootImage, error) { var images []BootImage - for _, opt := range ack.Options() { - if opt.Code() == dhcpv4.OptionVendorSpecificInformation { - vendorOpt, err := ParseOptVendorSpecificInformation(opt.ToBytes()) - if err != nil { - return nil, err - } - bootImageOpts := vendorOpt.GetOptions(OptionBootImageList) - for _, opt := range bootImageOpts { - images = append(images, opt.(*OptBootImageList).Images...) - } - } + opt := ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation) + if opt == nil { + return nil, errors.New("ParseBootImageListFromAck: could not find vendor-specific option") + } + vendorOpt, err := ParseOptVendorSpecificInformation(opt.ToBytes()) + if err != nil { + return nil, err + } + bootImageOpts := vendorOpt.GetOptions(OptionBootImageList) + for _, opt := range bootImageOpts { + images = append(images, opt.(*OptBootImageList).Images...) } return images, nil } @@ -109,11 +109,8 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI // Find server IP address var serverIP net.IP - // TODO replace this loop with `ack.GetOneOption(OptionBootImageList)` - for _, opt := range ack.Options() { - if opt.Code() == dhcpv4.OptionServerIdentifier { - serverIP = opt.(*dhcpv4.OptServerIdentifier).ServerID - } + if opt := ack.GetOneOption(dhcpv4.OptionServerIdentifier); opt != nil { + serverIP = opt.(*dhcpv4.OptServerIdentifier).ServerID } if serverIP.To4() == nil { return nil, fmt.Errorf("could not parse server identifier from ACK") diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go index fb0a1e5..ad2a265 100644 --- a/dhcpv4/bsdp/bsdp_test.go +++ b/dhcpv4/bsdp/bsdp_test.go @@ -40,7 +40,7 @@ func TestParseBootImageListFromAck(t *testing.T) { func TestParseBootImageListFromAckNoVendorOption(t *testing.T) { ack, _ := dhcpv4.New() images, err := ParseBootImageListFromAck(*ack) - require.NoError(t, err) + require.Error(t, err) require.Empty(t, images, "no BootImages") } diff --git a/dhcpv4/bsdp/option_vendor_specific_information.go b/dhcpv4/bsdp/option_vendor_specific_information.go index 99c72d1..645f0c8 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information.go +++ b/dhcpv4/bsdp/option_vendor_specific_information.go @@ -143,7 +143,7 @@ func (o *OptVendorSpecificInformation) GetOptions(code dhcpv4.OptionCode) []dhcp } // GetOption returns the first suboption that matches the OptionCode code. -func (o *OptVendorSpecificInformation) GetOption(code dhcpv4.OptionCode) dhcpv4.Option { +func (o *OptVendorSpecificInformation) GetOneOption(code dhcpv4.OptionCode) dhcpv4.Option { opts := o.GetOptions(code) if len(opts) == 0 { return nil diff --git a/dhcpv4/bsdp/option_vendor_specific_information_test.go b/dhcpv4/bsdp/option_vendor_specific_information_test.go index f125d5a..bcd28ca 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information_test.go +++ b/dhcpv4/bsdp/option_vendor_specific_information_test.go @@ -153,7 +153,7 @@ func TestOptVendorSpecificInformationGetOptions(t *testing.T) { require.Equal(t, Version1_0, foundOpts[1].(*OptVersion).Version) } -func TestOptVendorSpecificInformationGetOption(t *testing.T) { +func TestOptVendorSpecificInformationGetOneOption(t *testing.T) { // No option o := &OptVendorSpecificInformation{ []dhcpv4.Option{ @@ -161,7 +161,7 @@ func TestOptVendorSpecificInformationGetOption(t *testing.T) { &OptVersion{Version1_1}, }, } - foundOpt := o.GetOption(OptionBootImageList) + foundOpt := o.GetOneOption(OptionBootImageList) require.Nil(t, foundOpt, "should not get options") // One option @@ -171,7 +171,7 @@ func TestOptVendorSpecificInformationGetOption(t *testing.T) { &OptVersion{Version1_1}, }, } - foundOpt = o.GetOption(OptionMessageType) + foundOpt = o.GetOneOption(OptionMessageType) require.Equal(t, MessageTypeList, foundOpt.(*OptMessageType).Type) // Multiple options @@ -182,6 +182,6 @@ func TestOptVendorSpecificInformationGetOption(t *testing.T) { &OptVersion{Version1_0}, }, } - foundOpt = o.GetOption(OptionVersion) + foundOpt = o.GetOneOption(OptionVersion) require.Equal(t, Version1_1, foundOpt.(*OptVersion).Version) } |