summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv4')
-rw-r--r--dhcpv4/bsdp/bsdp.go21
-rw-r--r--dhcpv4/bsdp/bsdp_test.go25
2 files changed, 19 insertions, 27 deletions
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go
index 172399b..8d4430a 100644
--- a/dhcpv4/bsdp/bsdp.go
+++ b/dhcpv4/bsdp/bsdp.go
@@ -54,24 +54,21 @@ func needsReplyPort(replyPort uint16) bool {
// MessageTypeFromPacket extracts the BSDP message type (LIST, SELECT) from the
// vendor-specific options and returns it. If the message type option cannot be
// found, returns false.
-func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) (MessageType, bool) {
+func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) *MessageType {
var (
- messageType MessageType
- vendorOpts *OptVendorSpecificInformation
- err error
+ vendorOpts *OptVendorSpecificInformation
+ err error
)
for _, opt := range packet.GetOption(dhcpv4.OptionVendorSpecificInformation) {
- if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()); err != nil {
- return messageType, false
- }
- if o := vendorOpts.GetOneOption(OptionMessageType); o != nil {
- if optMessageType, ok := o.(*OptMessageType); ok {
- return optMessageType.Type, true
+ if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()); err == nil {
+ if o := vendorOpts.GetOneOption(OptionMessageType); o != nil {
+ if optMessageType, ok := o.(*OptMessageType); ok {
+ return &optMessageType.Type
+ }
}
- return messageType, false
}
}
- return messageType, false
+ return nil
}
// NewInformListForInterface creates a new INFORM packet for interface ifname
diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go
index c9c868a..9da86fa 100644
--- a/dhcpv4/bsdp/bsdp_test.go
+++ b/dhcpv4/bsdp/bsdp_test.go
@@ -375,27 +375,26 @@ func TestNewReplyForInformSelect(t *testing.T) {
func TestMessageTypeForPacket(t *testing.T) {
var (
pkt *dhcpv4.DHCPv4
- gotMessageType MessageType
- gotOK bool
+ gotMessageType *MessageType
)
+ list := new(MessageType)
+ *list = MessageTypeList
+
testcases := []struct {
tcName string
opts []dhcpv4.Option
- wantOK bool
- wantMessageType MessageType
+ wantMessageType *MessageType
}{
{
tcName: "No options",
opts: []dhcpv4.Option{},
- wantOK: false,
},
{
tcName: "Some options, no vendor opts",
opts: []dhcpv4.Option{
&dhcpv4.OptHostName{HostName: "foobar1234"},
},
- wantOK: false,
},
{
tcName: "Vendor opts, no message type",
@@ -407,7 +406,6 @@ func TestMessageTypeForPacket(t *testing.T) {
},
},
},
- wantOK: false,
},
{
tcName: "Vendor opts, with message type",
@@ -420,8 +418,7 @@ func TestMessageTypeForPacket(t *testing.T) {
},
},
},
- wantOK: true,
- wantMessageType: MessageTypeList,
+ wantMessageType: list,
},
}
for _, tt := range testcases {
@@ -430,12 +427,10 @@ func TestMessageTypeForPacket(t *testing.T) {
for _, opt := range tt.opts {
pkt.AddOption(opt)
}
- gotMessageType, gotOK = MessageTypeFromPacket(pkt)
- if tt.wantOK {
- require.True(t, gotOK)
- require.Equal(t, tt.wantMessageType, gotMessageType)
- } else {
- require.False(t, gotOK)
+ gotMessageType = MessageTypeFromPacket(pkt)
+ require.Equal(t, tt.wantMessageType, gotMessageType)
+ if tt.wantMessageType != nil {
+ require.Equal(t, *tt.wantMessageType, *gotMessageType)
}
})
}