summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv4')
-rw-r--r--dhcpv4/bsdp/bsdp.go2
-rw-r--r--dhcpv4/bsdp/bsdp_test.go53
-rw-r--r--dhcpv4/bsdp/option_vendor_specific_information.go6
-rw-r--r--dhcpv4/bsdp/option_vendor_specific_information_test.go6
-rw-r--r--dhcpv4/dhcpv4.go12
-rw-r--r--dhcpv4/dhcpv4_test.go42
6 files changed, 53 insertions, 68 deletions
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go
index 22360ef..42edf7f 100644
--- a/dhcpv4/bsdp/bsdp.go
+++ b/dhcpv4/bsdp/bsdp.go
@@ -26,7 +26,7 @@ func ParseBootImageListFromAck(ack dhcpv4.DHCPv4) ([]BootImage, error) {
if err != nil {
return nil, err
}
- bootImageOpts := vendorOpt.GetOptions(OptionBootImageList)
+ bootImageOpts := vendorOpt.GetOption(OptionBootImageList)
for _, opt := range bootImageOpts {
images = append(images, opt.(*OptBootImageList).Images...)
}
diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go
index da82d43..4cc55a6 100644
--- a/dhcpv4/bsdp/bsdp_test.go
+++ b/dhcpv4/bsdp/bsdp_test.go
@@ -9,21 +9,6 @@ import (
"github.com/stretchr/testify/require"
)
-func RequireEqualIPAddr(t *testing.T, a, b net.IP, msg ...interface{}) {
- if !net.IP.Equal(a, b) {
- t.Fatalf("Invalid %s. %v != %v", msg, a, b)
- }
-}
-
-func RequireHasOption(t *testing.T, opts []dhcpv4.Option, opcode dhcpv4.OptionCode) {
- for _, opt := range opts {
- if opt.Code() == opcode {
- return
- }
- }
- require.FailNow(t, "option not present in opts", dhcpv4.OptionCodeToString[opcode])
-}
-
func TestParseBootImageListFromAck(t *testing.T) {
expectedBootImages := []BootImage{
BootImage{
@@ -73,16 +58,16 @@ func TestNewInformList_NoReplyPort(t *testing.T) {
m, err := NewInformList(hwAddr, localIP, 0)
require.NoError(t, err)
- RequireHasOption(t, m.Options(), dhcpv4.OptionVendorSpecificInformation)
- RequireHasOption(t, m.Options(), dhcpv4.OptionParameterRequestList)
- RequireHasOption(t, m.Options(), dhcpv4.OptionMaximumDHCPMessageSize)
- RequireHasOption(t, m.Options(), dhcpv4.OptionEnd)
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionVendorSpecificInformation))
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionParameterRequestList))
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionMaximumDHCPMessageSize))
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionEnd))
opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation)
require.NotNil(t, opt, "vendor opts not present")
vendorInfo := opt.(*OptVendorSpecificInformation)
- RequireHasOption(t, vendorInfo.Options, OptionMessageType)
- RequireHasOption(t, vendorInfo.Options, OptionVersion)
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionMessageType))
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionVersion))
opt = vendorInfo.GetOneOption(OptionMessageType)
require.Equal(t, MessageTypeList, opt.(*OptMessageType).Type)
@@ -104,7 +89,7 @@ func TestNewInformList_ReplyPort(t *testing.T) {
opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation)
vendorInfo := opt.(*OptVendorSpecificInformation)
- RequireHasOption(t, vendorInfo.Options, OptionReplyPort)
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionReplyPort))
opt = vendorInfo.GetOneOption(OptionReplyPort)
require.Equal(t, replyPort, opt.(*OptReplyPort).Port)
@@ -146,27 +131,27 @@ func TestInformSelectForAck_Broadcast(t *testing.T) {
require.True(t, m.IsBroadcast())
// Validate options.
- RequireHasOption(t, m.Options(), dhcpv4.OptionClassIdentifier)
- RequireHasOption(t, m.Options(), dhcpv4.OptionParameterRequestList)
- RequireHasOption(t, m.Options(), dhcpv4.OptionDHCPMessageType)
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionClassIdentifier))
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionParameterRequestList))
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionDHCPMessageType))
opt := m.GetOneOption(dhcpv4.OptionDHCPMessageType)
require.Equal(t, dhcpv4.MessageTypeInform, opt.(*dhcpv4.OptMessageType).MessageType)
- RequireHasOption(t, m.Options(), dhcpv4.OptionEnd)
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionEnd))
// Validate vendor opts.
- RequireHasOption(t, m.Options(), dhcpv4.OptionVendorSpecificInformation)
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionVendorSpecificInformation))
opt = m.GetOneOption(dhcpv4.OptionVendorSpecificInformation)
vendorInfo := opt.(*OptVendorSpecificInformation)
- RequireHasOption(t, vendorInfo.Options, OptionMessageType)
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionMessageType))
opt = vendorInfo.GetOneOption(OptionMessageType)
require.Equal(t, MessageTypeSelect, opt.(*OptMessageType).Type)
- RequireHasOption(t, vendorInfo.Options, OptionVersion)
- RequireHasOption(t, vendorInfo.Options, OptionSelectedBootImageID)
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionVersion))
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionSelectedBootImageID))
opt = vendorInfo.GetOneOption(OptionSelectedBootImageID)
require.Equal(t, bootImage.ID, opt.(*OptSelectedBootImageID).ID)
- RequireHasOption(t, vendorInfo.Options, OptionServerIdentifier)
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionServerIdentifier))
opt = vendorInfo.GetOneOption(OptionServerIdentifier)
- RequireEqualIPAddr(t, serverID, opt.(*OptServerIdentifier).ServerID)
+ require.True(t, serverID.Equal(opt.(*OptServerIdentifier).ServerID))
}
func TestInformSelectForAck_NoServerID(t *testing.T) {
@@ -226,10 +211,10 @@ func TestInformSelectForAck_ReplyPort(t *testing.T) {
m, err := InformSelectForAck(*ack, replyPort, bootImage)
require.NoError(t, err)
- RequireHasOption(t, m.Options(), dhcpv4.OptionVendorSpecificInformation)
+ require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionVendorSpecificInformation))
opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation)
vendorInfo := opt.(*OptVendorSpecificInformation)
- RequireHasOption(t, vendorInfo.Options, OptionReplyPort)
+ require.True(t, dhcpv4.HasOption(vendorInfo, OptionReplyPort))
opt = vendorInfo.GetOneOption(OptionReplyPort)
require.Equal(t, replyPort, opt.(*OptReplyPort).Port)
}
diff --git a/dhcpv4/bsdp/option_vendor_specific_information.go b/dhcpv4/bsdp/option_vendor_specific_information.go
index 5c8533e..e735b57 100644
--- a/dhcpv4/bsdp/option_vendor_specific_information.go
+++ b/dhcpv4/bsdp/option_vendor_specific_information.go
@@ -131,8 +131,8 @@ func (o *OptVendorSpecificInformation) Length() int {
return length
}
-// GetOptions returns all suboptions that match the given OptionCode code.
-func (o *OptVendorSpecificInformation) GetOptions(code dhcpv4.OptionCode) []dhcpv4.Option {
+// GetOption returns all suboptions that match the given OptionCode code.
+func (o *OptVendorSpecificInformation) GetOption(code dhcpv4.OptionCode) []dhcpv4.Option {
var opts []dhcpv4.Option
for _, opt := range o.Options {
if opt.Code() == code {
@@ -144,7 +144,7 @@ func (o *OptVendorSpecificInformation) GetOptions(code dhcpv4.OptionCode) []dhcp
// GetOneOption returns the first suboption that matches the OptionCode code.
func (o *OptVendorSpecificInformation) GetOneOption(code dhcpv4.OptionCode) dhcpv4.Option {
- opts := o.GetOptions(code)
+ opts := o.GetOption(code)
if len(opts) == 0 {
return nil
}
diff --git a/dhcpv4/bsdp/option_vendor_specific_information_test.go b/dhcpv4/bsdp/option_vendor_specific_information_test.go
index 9827618..5e7689d 100644
--- a/dhcpv4/bsdp/option_vendor_specific_information_test.go
+++ b/dhcpv4/bsdp/option_vendor_specific_information_test.go
@@ -136,7 +136,7 @@ func TestOptVendorSpecificInformationGetOptions(t *testing.T) {
&OptVersion{Version1_1},
},
}
- foundOpts := o.GetOptions(OptionBootImageList)
+ foundOpts := o.GetOption(OptionBootImageList)
require.Empty(t, foundOpts, "should not get any options")
// One option
@@ -146,7 +146,7 @@ func TestOptVendorSpecificInformationGetOptions(t *testing.T) {
&OptVersion{Version1_1},
},
}
- foundOpts = o.GetOptions(OptionMessageType)
+ foundOpts = o.GetOption(OptionMessageType)
require.Equal(t, 1, len(foundOpts), "should only get one option")
require.Equal(t, MessageTypeList, foundOpts[0].(*OptMessageType).Type)
@@ -158,7 +158,7 @@ func TestOptVendorSpecificInformationGetOptions(t *testing.T) {
&OptVersion{Version1_0},
},
}
- foundOpts = o.GetOptions(OptionVersion)
+ foundOpts = o.GetOption(OptionVersion)
require.Equal(t, 2, len(foundOpts), "should get two options")
require.Equal(t, Version1_1, foundOpts[0].(*OptVersion).Version)
require.Equal(t, Version1_0, foundOpts[1].(*OptVersion).Version)
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index c7cf2cf..eb0f467 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -679,3 +679,15 @@ func (d *DHCPv4) ToBytes() []byte {
}
return ret
}
+
+// OptionGetter is a interface that knows how to retrieve an option from a
+// structure of options given an OptionCode.
+type OptionGetter interface {
+ GetOption(OptionCode) []Option
+ GetOneOption(OptionCode) Option
+}
+
+// HasOption checks whether the OptionGetter `o` has the given `opcode` Option.
+func HasOption(o OptionGetter, opcode OptionCode) bool {
+ return o.GetOneOption(opcode) != nil
+}
diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go
index 204ba6a..e0afaba 100644
--- a/dhcpv4/dhcpv4_test.go
+++ b/dhcpv4/dhcpv4_test.go
@@ -8,19 +8,7 @@ import (
"github.com/stretchr/testify/require"
)
-func RequireEqualIPAddr(t *testing.T, a, b net.IP, msg ...interface{}) {
- if !net.IP.Equal(a, b) {
- t.Fatalf("Invalid %s. %v != %v", msg, a, b)
- }
-}
-
-func RequireHasOption(t *testing.T, packet *DHCPv4, opcode OptionCode) {
- require.NotNil(t, packet, "packet cannot be nil")
- packetOpt := packet.GetOneOption(opcode)
- require.NotNil(t, packetOpt, "option not present in packet")
-}
-
-func TestIPv4AddrsForInterface(t *testing.T) {
+func TestGetExternalIPv4Addrs(t *testing.T) {
addrs4and6 := []net.Addr{
&net.IPAddr{IP: net.IP{1, 2, 3, 4}},
&net.IPAddr{IP: net.IP{4, 3, 2, 1}},
@@ -80,9 +68,9 @@ func TestFromBytes(t *testing.T) {
require.Equal(t, d.TransactionID(), uint32(0xaabbccdd))
require.Equal(t, d.NumSeconds(), uint16(3))
require.Equal(t, d.Flags(), uint16(1))
- RequireEqualIPAddr(t, d.ClientIPAddr(), net.IPv4zero)
- RequireEqualIPAddr(t, d.YourIPAddr(), net.IPv4zero)
- RequireEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4zero)
+ require.True(t, d.ClientIPAddr().Equal(net.IPv4zero))
+ require.True(t, d.YourIPAddr().Equal(net.IPv4zero))
+ require.True(t, d.GatewayIPAddr().Equal(net.IPv4zero))
clientHwAddr := d.ClientHwAddr()
require.Equal(t, clientHwAddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
hostname := d.ServerHostName()
@@ -204,24 +192,24 @@ func TestSettersAndGetters(t *testing.T) {
require.Equal(t, uint16(0), d.Flags())
// getter/setter for ClientIPAddr
- RequireEqualIPAddr(t, net.IPv4(1, 2, 3, 4), d.ClientIPAddr())
+ require.True(t, d.ClientIPAddr().Equal(net.IPv4(1, 2, 3, 4)))
d.SetClientIPAddr(net.IPv4(4, 3, 2, 1))
- RequireEqualIPAddr(t, net.IPv4(4, 3, 2, 1), d.ClientIPAddr())
+ require.True(t, d.ClientIPAddr().Equal(net.IPv4(4, 3, 2, 1)))
// getter/setter for YourIPAddr
- RequireEqualIPAddr(t, net.IPv4(5, 6, 7, 8), d.YourIPAddr())
+ require.True(t, d.YourIPAddr().Equal(net.IPv4(5, 6, 7, 8)))
d.SetYourIPAddr(net.IPv4(8, 7, 6, 5))
- RequireEqualIPAddr(t, net.IPv4(8, 7, 6, 5), d.YourIPAddr())
+ require.True(t, d.YourIPAddr().Equal(net.IPv4(8, 7, 6, 5)))
// getter/setter for ServerIPAddr
- RequireEqualIPAddr(t, net.IPv4(9, 10, 11, 12), d.ServerIPAddr())
+ require.True(t, d.ServerIPAddr().Equal(net.IPv4(9, 10, 11, 12)))
d.SetServerIPAddr(net.IPv4(12, 11, 10, 9))
- RequireEqualIPAddr(t, net.IPv4(12, 11, 10, 9), d.ServerIPAddr())
+ require.True(t, d.ServerIPAddr().Equal(net.IPv4(12, 11, 10, 9)))
// getter/setter for GatewayIPAddr
- RequireEqualIPAddr(t, net.IPv4(13, 14, 15, 16), d.GatewayIPAddr())
+ require.True(t, d.GatewayIPAddr().Equal(net.IPv4(13, 14, 15, 16)))
d.SetGatewayIPAddr(net.IPv4(16, 15, 14, 13))
- RequireEqualIPAddr(t, net.IPv4(16, 15, 14, 13), d.GatewayIPAddr())
+ require.True(t, d.GatewayIPAddr().Equal(net.IPv4(16, 15, 14, 13)))
// getter/setter for ClientHwAddr
hwaddr := d.ClientHwAddr()
@@ -459,8 +447,8 @@ func TestNewDiscovery(t *testing.T) {
require.Equal(t, expectedHwAddr, m.ClientHwAddr())
require.Equal(t, len(hwAddr), int(m.HwAddrLen()))
require.True(t, m.IsBroadcast())
- RequireHasOption(t, m, OptionParameterRequestList)
- RequireHasOption(t, m, OptionEnd)
+ require.True(t, HasOption(m, OptionParameterRequestList))
+ require.True(t, HasOption(m, OptionEnd))
}
func TestNewInform(t *testing.T) {
@@ -477,7 +465,7 @@ func TestNewInform(t *testing.T) {
require.Equal(t, len(hwAddr), int(m.HwAddrLen()))
require.NotNil(t, m.MessageType())
require.Equal(t, MessageTypeInform, *m.MessageType())
- RequireEqualIPAddr(t, localIP, m.ClientIPAddr())
+ require.True(t, m.ClientIPAddr().Equal(localIP))
}
// TODO