diff options
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/bsdp/bsdp.go | 2 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp_test.go | 53 | ||||
-rw-r--r-- | dhcpv4/bsdp/option_vendor_specific_information.go | 6 | ||||
-rw-r--r-- | dhcpv4/bsdp/option_vendor_specific_information_test.go | 6 | ||||
-rw-r--r-- | dhcpv4/dhcpv4.go | 12 | ||||
-rw-r--r-- | dhcpv4/dhcpv4_test.go | 42 |
6 files changed, 53 insertions, 68 deletions
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 22360ef..42edf7f 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -26,7 +26,7 @@ func ParseBootImageListFromAck(ack dhcpv4.DHCPv4) ([]BootImage, error) { if err != nil { return nil, err } - bootImageOpts := vendorOpt.GetOptions(OptionBootImageList) + bootImageOpts := vendorOpt.GetOption(OptionBootImageList) for _, opt := range bootImageOpts { images = append(images, opt.(*OptBootImageList).Images...) } diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go index da82d43..4cc55a6 100644 --- a/dhcpv4/bsdp/bsdp_test.go +++ b/dhcpv4/bsdp/bsdp_test.go @@ -9,21 +9,6 @@ import ( "github.com/stretchr/testify/require" ) -func RequireEqualIPAddr(t *testing.T, a, b net.IP, msg ...interface{}) { - if !net.IP.Equal(a, b) { - t.Fatalf("Invalid %s. %v != %v", msg, a, b) - } -} - -func RequireHasOption(t *testing.T, opts []dhcpv4.Option, opcode dhcpv4.OptionCode) { - for _, opt := range opts { - if opt.Code() == opcode { - return - } - } - require.FailNow(t, "option not present in opts", dhcpv4.OptionCodeToString[opcode]) -} - func TestParseBootImageListFromAck(t *testing.T) { expectedBootImages := []BootImage{ BootImage{ @@ -73,16 +58,16 @@ func TestNewInformList_NoReplyPort(t *testing.T) { m, err := NewInformList(hwAddr, localIP, 0) require.NoError(t, err) - RequireHasOption(t, m.Options(), dhcpv4.OptionVendorSpecificInformation) - RequireHasOption(t, m.Options(), dhcpv4.OptionParameterRequestList) - RequireHasOption(t, m.Options(), dhcpv4.OptionMaximumDHCPMessageSize) - RequireHasOption(t, m.Options(), dhcpv4.OptionEnd) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionVendorSpecificInformation)) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionParameterRequestList)) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionMaximumDHCPMessageSize)) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionEnd)) opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) require.NotNil(t, opt, "vendor opts not present") vendorInfo := opt.(*OptVendorSpecificInformation) - RequireHasOption(t, vendorInfo.Options, OptionMessageType) - RequireHasOption(t, vendorInfo.Options, OptionVersion) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionMessageType)) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionVersion)) opt = vendorInfo.GetOneOption(OptionMessageType) require.Equal(t, MessageTypeList, opt.(*OptMessageType).Type) @@ -104,7 +89,7 @@ func TestNewInformList_ReplyPort(t *testing.T) { opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) vendorInfo := opt.(*OptVendorSpecificInformation) - RequireHasOption(t, vendorInfo.Options, OptionReplyPort) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionReplyPort)) opt = vendorInfo.GetOneOption(OptionReplyPort) require.Equal(t, replyPort, opt.(*OptReplyPort).Port) @@ -146,27 +131,27 @@ func TestInformSelectForAck_Broadcast(t *testing.T) { require.True(t, m.IsBroadcast()) // Validate options. - RequireHasOption(t, m.Options(), dhcpv4.OptionClassIdentifier) - RequireHasOption(t, m.Options(), dhcpv4.OptionParameterRequestList) - RequireHasOption(t, m.Options(), dhcpv4.OptionDHCPMessageType) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionClassIdentifier)) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionParameterRequestList)) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionDHCPMessageType)) opt := m.GetOneOption(dhcpv4.OptionDHCPMessageType) require.Equal(t, dhcpv4.MessageTypeInform, opt.(*dhcpv4.OptMessageType).MessageType) - RequireHasOption(t, m.Options(), dhcpv4.OptionEnd) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionEnd)) // Validate vendor opts. - RequireHasOption(t, m.Options(), dhcpv4.OptionVendorSpecificInformation) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionVendorSpecificInformation)) opt = m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) vendorInfo := opt.(*OptVendorSpecificInformation) - RequireHasOption(t, vendorInfo.Options, OptionMessageType) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionMessageType)) opt = vendorInfo.GetOneOption(OptionMessageType) require.Equal(t, MessageTypeSelect, opt.(*OptMessageType).Type) - RequireHasOption(t, vendorInfo.Options, OptionVersion) - RequireHasOption(t, vendorInfo.Options, OptionSelectedBootImageID) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionVersion)) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionSelectedBootImageID)) opt = vendorInfo.GetOneOption(OptionSelectedBootImageID) require.Equal(t, bootImage.ID, opt.(*OptSelectedBootImageID).ID) - RequireHasOption(t, vendorInfo.Options, OptionServerIdentifier) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionServerIdentifier)) opt = vendorInfo.GetOneOption(OptionServerIdentifier) - RequireEqualIPAddr(t, serverID, opt.(*OptServerIdentifier).ServerID) + require.True(t, serverID.Equal(opt.(*OptServerIdentifier).ServerID)) } func TestInformSelectForAck_NoServerID(t *testing.T) { @@ -226,10 +211,10 @@ func TestInformSelectForAck_ReplyPort(t *testing.T) { m, err := InformSelectForAck(*ack, replyPort, bootImage) require.NoError(t, err) - RequireHasOption(t, m.Options(), dhcpv4.OptionVendorSpecificInformation) + require.True(t, dhcpv4.HasOption(m, dhcpv4.OptionVendorSpecificInformation)) opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) vendorInfo := opt.(*OptVendorSpecificInformation) - RequireHasOption(t, vendorInfo.Options, OptionReplyPort) + require.True(t, dhcpv4.HasOption(vendorInfo, OptionReplyPort)) opt = vendorInfo.GetOneOption(OptionReplyPort) require.Equal(t, replyPort, opt.(*OptReplyPort).Port) } diff --git a/dhcpv4/bsdp/option_vendor_specific_information.go b/dhcpv4/bsdp/option_vendor_specific_information.go index 5c8533e..e735b57 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information.go +++ b/dhcpv4/bsdp/option_vendor_specific_information.go @@ -131,8 +131,8 @@ func (o *OptVendorSpecificInformation) Length() int { return length } -// GetOptions returns all suboptions that match the given OptionCode code. -func (o *OptVendorSpecificInformation) GetOptions(code dhcpv4.OptionCode) []dhcpv4.Option { +// GetOption returns all suboptions that match the given OptionCode code. +func (o *OptVendorSpecificInformation) GetOption(code dhcpv4.OptionCode) []dhcpv4.Option { var opts []dhcpv4.Option for _, opt := range o.Options { if opt.Code() == code { @@ -144,7 +144,7 @@ func (o *OptVendorSpecificInformation) GetOptions(code dhcpv4.OptionCode) []dhcp // GetOneOption returns the first suboption that matches the OptionCode code. func (o *OptVendorSpecificInformation) GetOneOption(code dhcpv4.OptionCode) dhcpv4.Option { - opts := o.GetOptions(code) + opts := o.GetOption(code) if len(opts) == 0 { return nil } diff --git a/dhcpv4/bsdp/option_vendor_specific_information_test.go b/dhcpv4/bsdp/option_vendor_specific_information_test.go index 9827618..5e7689d 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information_test.go +++ b/dhcpv4/bsdp/option_vendor_specific_information_test.go @@ -136,7 +136,7 @@ func TestOptVendorSpecificInformationGetOptions(t *testing.T) { &OptVersion{Version1_1}, }, } - foundOpts := o.GetOptions(OptionBootImageList) + foundOpts := o.GetOption(OptionBootImageList) require.Empty(t, foundOpts, "should not get any options") // One option @@ -146,7 +146,7 @@ func TestOptVendorSpecificInformationGetOptions(t *testing.T) { &OptVersion{Version1_1}, }, } - foundOpts = o.GetOptions(OptionMessageType) + foundOpts = o.GetOption(OptionMessageType) require.Equal(t, 1, len(foundOpts), "should only get one option") require.Equal(t, MessageTypeList, foundOpts[0].(*OptMessageType).Type) @@ -158,7 +158,7 @@ func TestOptVendorSpecificInformationGetOptions(t *testing.T) { &OptVersion{Version1_0}, }, } - foundOpts = o.GetOptions(OptionVersion) + foundOpts = o.GetOption(OptionVersion) require.Equal(t, 2, len(foundOpts), "should get two options") require.Equal(t, Version1_1, foundOpts[0].(*OptVersion).Version) require.Equal(t, Version1_0, foundOpts[1].(*OptVersion).Version) diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index c7cf2cf..eb0f467 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -679,3 +679,15 @@ func (d *DHCPv4) ToBytes() []byte { } return ret } + +// OptionGetter is a interface that knows how to retrieve an option from a +// structure of options given an OptionCode. +type OptionGetter interface { + GetOption(OptionCode) []Option + GetOneOption(OptionCode) Option +} + +// HasOption checks whether the OptionGetter `o` has the given `opcode` Option. +func HasOption(o OptionGetter, opcode OptionCode) bool { + return o.GetOneOption(opcode) != nil +} diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index 204ba6a..e0afaba 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -8,19 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -func RequireEqualIPAddr(t *testing.T, a, b net.IP, msg ...interface{}) { - if !net.IP.Equal(a, b) { - t.Fatalf("Invalid %s. %v != %v", msg, a, b) - } -} - -func RequireHasOption(t *testing.T, packet *DHCPv4, opcode OptionCode) { - require.NotNil(t, packet, "packet cannot be nil") - packetOpt := packet.GetOneOption(opcode) - require.NotNil(t, packetOpt, "option not present in packet") -} - -func TestIPv4AddrsForInterface(t *testing.T) { +func TestGetExternalIPv4Addrs(t *testing.T) { addrs4and6 := []net.Addr{ &net.IPAddr{IP: net.IP{1, 2, 3, 4}}, &net.IPAddr{IP: net.IP{4, 3, 2, 1}}, @@ -80,9 +68,9 @@ func TestFromBytes(t *testing.T) { require.Equal(t, d.TransactionID(), uint32(0xaabbccdd)) require.Equal(t, d.NumSeconds(), uint16(3)) require.Equal(t, d.Flags(), uint16(1)) - RequireEqualIPAddr(t, d.ClientIPAddr(), net.IPv4zero) - RequireEqualIPAddr(t, d.YourIPAddr(), net.IPv4zero) - RequireEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4zero) + require.True(t, d.ClientIPAddr().Equal(net.IPv4zero)) + require.True(t, d.YourIPAddr().Equal(net.IPv4zero)) + require.True(t, d.GatewayIPAddr().Equal(net.IPv4zero)) clientHwAddr := d.ClientHwAddr() require.Equal(t, clientHwAddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) hostname := d.ServerHostName() @@ -204,24 +192,24 @@ func TestSettersAndGetters(t *testing.T) { require.Equal(t, uint16(0), d.Flags()) // getter/setter for ClientIPAddr - RequireEqualIPAddr(t, net.IPv4(1, 2, 3, 4), d.ClientIPAddr()) + require.True(t, d.ClientIPAddr().Equal(net.IPv4(1, 2, 3, 4))) d.SetClientIPAddr(net.IPv4(4, 3, 2, 1)) - RequireEqualIPAddr(t, net.IPv4(4, 3, 2, 1), d.ClientIPAddr()) + require.True(t, d.ClientIPAddr().Equal(net.IPv4(4, 3, 2, 1))) // getter/setter for YourIPAddr - RequireEqualIPAddr(t, net.IPv4(5, 6, 7, 8), d.YourIPAddr()) + require.True(t, d.YourIPAddr().Equal(net.IPv4(5, 6, 7, 8))) d.SetYourIPAddr(net.IPv4(8, 7, 6, 5)) - RequireEqualIPAddr(t, net.IPv4(8, 7, 6, 5), d.YourIPAddr()) + require.True(t, d.YourIPAddr().Equal(net.IPv4(8, 7, 6, 5))) // getter/setter for ServerIPAddr - RequireEqualIPAddr(t, net.IPv4(9, 10, 11, 12), d.ServerIPAddr()) + require.True(t, d.ServerIPAddr().Equal(net.IPv4(9, 10, 11, 12))) d.SetServerIPAddr(net.IPv4(12, 11, 10, 9)) - RequireEqualIPAddr(t, net.IPv4(12, 11, 10, 9), d.ServerIPAddr()) + require.True(t, d.ServerIPAddr().Equal(net.IPv4(12, 11, 10, 9))) // getter/setter for GatewayIPAddr - RequireEqualIPAddr(t, net.IPv4(13, 14, 15, 16), d.GatewayIPAddr()) + require.True(t, d.GatewayIPAddr().Equal(net.IPv4(13, 14, 15, 16))) d.SetGatewayIPAddr(net.IPv4(16, 15, 14, 13)) - RequireEqualIPAddr(t, net.IPv4(16, 15, 14, 13), d.GatewayIPAddr()) + require.True(t, d.GatewayIPAddr().Equal(net.IPv4(16, 15, 14, 13))) // getter/setter for ClientHwAddr hwaddr := d.ClientHwAddr() @@ -459,8 +447,8 @@ func TestNewDiscovery(t *testing.T) { require.Equal(t, expectedHwAddr, m.ClientHwAddr()) require.Equal(t, len(hwAddr), int(m.HwAddrLen())) require.True(t, m.IsBroadcast()) - RequireHasOption(t, m, OptionParameterRequestList) - RequireHasOption(t, m, OptionEnd) + require.True(t, HasOption(m, OptionParameterRequestList)) + require.True(t, HasOption(m, OptionEnd)) } func TestNewInform(t *testing.T) { @@ -477,7 +465,7 @@ func TestNewInform(t *testing.T) { require.Equal(t, len(hwAddr), int(m.HwAddrLen())) require.NotNil(t, m.MessageType()) require.Equal(t, MessageTypeInform, *m.MessageType()) - RequireEqualIPAddr(t, localIP, m.ClientIPAddr()) + require.True(t, m.ClientIPAddr().Equal(localIP)) } // TODO |