diff options
-rw-r--r-- | dhcpv4/bsdp/bsdp.go | 21 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp_test.go | 25 |
2 files changed, 19 insertions, 27 deletions
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 172399b..8d4430a 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -54,24 +54,21 @@ func needsReplyPort(replyPort uint16) bool { // MessageTypeFromPacket extracts the BSDP message type (LIST, SELECT) from the // vendor-specific options and returns it. If the message type option cannot be // found, returns false. -func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) (MessageType, bool) { +func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) *MessageType { var ( - messageType MessageType - vendorOpts *OptVendorSpecificInformation - err error + vendorOpts *OptVendorSpecificInformation + err error ) for _, opt := range packet.GetOption(dhcpv4.OptionVendorSpecificInformation) { - if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()); err != nil { - return messageType, false - } - if o := vendorOpts.GetOneOption(OptionMessageType); o != nil { - if optMessageType, ok := o.(*OptMessageType); ok { - return optMessageType.Type, true + if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()); err == nil { + if o := vendorOpts.GetOneOption(OptionMessageType); o != nil { + if optMessageType, ok := o.(*OptMessageType); ok { + return &optMessageType.Type + } } - return messageType, false } } - return messageType, false + return nil } // NewInformListForInterface creates a new INFORM packet for interface ifname diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go index c9c868a..9da86fa 100644 --- a/dhcpv4/bsdp/bsdp_test.go +++ b/dhcpv4/bsdp/bsdp_test.go @@ -375,27 +375,26 @@ func TestNewReplyForInformSelect(t *testing.T) { func TestMessageTypeForPacket(t *testing.T) { var ( pkt *dhcpv4.DHCPv4 - gotMessageType MessageType - gotOK bool + gotMessageType *MessageType ) + list := new(MessageType) + *list = MessageTypeList + testcases := []struct { tcName string opts []dhcpv4.Option - wantOK bool - wantMessageType MessageType + wantMessageType *MessageType }{ { tcName: "No options", opts: []dhcpv4.Option{}, - wantOK: false, }, { tcName: "Some options, no vendor opts", opts: []dhcpv4.Option{ &dhcpv4.OptHostName{HostName: "foobar1234"}, }, - wantOK: false, }, { tcName: "Vendor opts, no message type", @@ -407,7 +406,6 @@ func TestMessageTypeForPacket(t *testing.T) { }, }, }, - wantOK: false, }, { tcName: "Vendor opts, with message type", @@ -420,8 +418,7 @@ func TestMessageTypeForPacket(t *testing.T) { }, }, }, - wantOK: true, - wantMessageType: MessageTypeList, + wantMessageType: list, }, } for _, tt := range testcases { @@ -430,12 +427,10 @@ func TestMessageTypeForPacket(t *testing.T) { for _, opt := range tt.opts { pkt.AddOption(opt) } - gotMessageType, gotOK = MessageTypeFromPacket(pkt) - if tt.wantOK { - require.True(t, gotOK) - require.Equal(t, tt.wantMessageType, gotMessageType) - } else { - require.False(t, gotOK) + gotMessageType = MessageTypeFromPacket(pkt) + require.Equal(t, tt.wantMessageType, gotMessageType) + if tt.wantMessageType != nil { + require.Equal(t, *tt.wantMessageType, *gotMessageType) } }) } |