diff options
76 files changed, 544 insertions, 1277 deletions
diff --git a/dhcpv4/bsdp/boot_image.go b/dhcpv4/bsdp/boot_image.go index 88a9404..fa9b1a6 100644 --- a/dhcpv4/bsdp/boot_image.go +++ b/dhcpv4/bsdp/boot_image.go @@ -3,6 +3,8 @@ package bsdp import ( "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // BootImageType represents the different BSDP boot image types. @@ -60,16 +62,14 @@ func (b BootImageID) String() string { return s + " " + t + " image" } -// BootImageIDFromBytes deserializes a collection of 4 bytes to a BootImageID. -func BootImageIDFromBytes(bytes []byte) (*BootImageID, error) { - if len(bytes) < 4 { - return nil, fmt.Errorf("not enough bytes to serialize BootImageID") - } - return &BootImageID{ - IsInstall: bytes[0]&0x80 != 0, - ImageType: BootImageType(bytes[0] & 0x7f), - Index: binary.BigEndian.Uint16(bytes[2:]), - }, nil +// Unmarshal reads b's binary representation from buf. +func (b *BootImageID) Unmarshal(buf *uio.Lexer) error { + byte0 := buf.Read8() + _ = buf.Read8() + b.IsInstall = byte0&0x80 != 0 + b.ImageType = BootImageType(byte0 & 0x7f) + b.Index = buf.Read16() + return buf.Error() } // BootImage describes a boot image - contains the boot image ID and the name. @@ -87,25 +87,16 @@ func (b *BootImage) ToBytes() []byte { } // String converts a BootImage to a human-readable representation. -func (b *BootImage) String() string { +func (b BootImage) String() string { return fmt.Sprintf("%v %v", b.Name, b.ID.String()) } -// BootImageFromBytes returns a deserialized BootImage struct from bytes. -func BootImageFromBytes(bytes []byte) (*BootImage, error) { - // Should at least contain 4 bytes of BootImageID + byte for length of - // boot image name. - if len(bytes) < 5 { - return nil, fmt.Errorf("not enough bytes to serialize BootImage") - } - imageID, err := BootImageIDFromBytes(bytes[:4]) - if err != nil { - return nil, err - } - nameLength := int(bytes[4]) - if 5+nameLength > len(bytes) { - return nil, fmt.Errorf("not enough bytes for BootImage") +// Unmarshal reads data from buf into b. +func (b *BootImage) Unmarshal(buf *uio.Lexer) error { + if err := (&b.ID).Unmarshal(buf); err != nil { + return err } - name := string(bytes[5 : 5+nameLength]) - return &BootImage{ID: *imageID, Name: name}, nil + nameLength := buf.Read8() + b.Name = string(buf.Consume(int(nameLength))) + return buf.Error() } diff --git a/dhcpv4/bsdp/boot_image_test.go b/dhcpv4/bsdp/boot_image_test.go index 004aa30..d8e3aeb 100644 --- a/dhcpv4/bsdp/boot_image_test.go +++ b/dhcpv4/bsdp/boot_image_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/u-root/u-root/pkg/uio" ) func TestBootImageIDToBytes(t *testing.T) { @@ -28,25 +29,23 @@ func TestBootImageIDFromBytes(t *testing.T) { ImageType: BootImageTypeMacOSX, Index: 0x1000, } - newBootImage, err := BootImageIDFromBytes(b.ToBytes()) - require.NoError(t, err) - require.Equal(t, b, *newBootImage) + var newBootImage BootImageID + require.NoError(t, uio.FromBigEndian(&newBootImage, b.ToBytes())) + require.Equal(t, b, newBootImage) b = BootImageID{ IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 0x1011, } - newBootImage, err = BootImageIDFromBytes(b.ToBytes()) - require.NoError(t, err) - require.Equal(t, b, *newBootImage) + require.NoError(t, uio.FromBigEndian(&newBootImage, b.ToBytes())) + require.Equal(t, b, newBootImage) } func TestBootImageIDFromBytesFail(t *testing.T) { serialized := []byte{0x81, 0, 0x10} // intentionally left short - deserialized, err := BootImageIDFromBytes(serialized) - require.Nil(t, deserialized) - require.Error(t, err) + var deserialized BootImageID + require.Error(t, uio.FromBigEndian(&deserialized, serialized)) } func TestBootImageIDString(t *testing.T) { @@ -97,8 +96,8 @@ func TestBootImageFromBytes(t *testing.T) { 7, // len(Name) 98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name } - b, err := BootImageFromBytes(input) - require.NoError(t, err) + var b BootImage + require.NoError(t, uio.FromBigEndian(&b, input)) expectedBootImage := BootImage{ ID: BootImageID{ IsInstall: false, @@ -107,15 +106,14 @@ func TestBootImageFromBytes(t *testing.T) { }, Name: "bsdp-21", } - require.Equal(t, expectedBootImage, *b) + require.Equal(t, expectedBootImage, b) } func TestBootImageFromBytesOnlyBootImageID(t *testing.T) { // Only a BootImageID, nothing else. input := []byte{0x1, 0, 0x10, 0x10} - b, err := BootImageFromBytes(input) - require.Nil(t, b) - require.Error(t, err) + var b BootImage + require.Error(t, uio.FromBigEndian(&b, input)) } func TestBootImageFromBytesShortBootImage(t *testing.T) { @@ -124,9 +122,8 @@ func TestBootImageFromBytesShortBootImage(t *testing.T) { 7, // len(Name) 98, 115, 100, 112, 45, 50, // Name bytes (intentionally off-by-one) } - b, err := BootImageFromBytes(input) - require.Nil(t, b) - require.Error(t, err) + var b BootImage + require.Error(t, uio.FromBigEndian(&b, input)) } func TestBootImageString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 3f97602..6bc0cfd 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -23,7 +23,7 @@ const AppleVendorID = "AAPLBSDPC" type ReplyConfig struct { ServerIP net.IP ServerHostname, BootFileName string - ServerPriority int + ServerPriority uint16 Images []BootImage DefaultImage, SelectedImage *BootImage } @@ -36,7 +36,7 @@ func ParseBootImageListFromAck(ack dhcpv4.DHCPv4) ([]BootImage, error) { if opt == nil { return nil, errors.New("ParseBootImageListFromAck: could not find vendor-specific option") } - vendorOpt, err := ParseOptVendorSpecificInformation(opt.ToBytes()) + vendorOpt, err := ParseOptVendorSpecificInformation(opt.ToBytes()[2:]) if err != nil { return nil, err } @@ -60,7 +60,7 @@ func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) *MessageType { err error ) for _, opt := range packet.GetOption(dhcpv4.OptionVendorSpecificInformation) { - if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()); err == nil { + if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()[2:]); err == nil { if o := vendorOpts.GetOneOption(OptionMessageType); o != nil { if optMessageType, ok := o.(*OptMessageType); ok { return &optMessageType.Type diff --git a/dhcpv4/bsdp/bsdp_option_boot_image_list.go b/dhcpv4/bsdp/bsdp_option_boot_image_list.go index 6417221..d018655 100644 --- a/dhcpv4/bsdp/bsdp_option_boot_image_list.go +++ b/dhcpv4/bsdp/bsdp_option_boot_image_list.go @@ -1,9 +1,8 @@ package bsdp import ( - "fmt" - "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) // Implements the BSDP option listing the boot images. @@ -17,34 +16,15 @@ type OptBootImageList struct { // ParseOptBootImageList constructs an OptBootImageList struct from a sequence // of bytes and returns it, or an error. func ParseOptBootImageList(data []byte) (*OptBootImageList, error) { - // Should have at least code + length - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionBootImageList { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionBootImageList, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("expected length %d, got %d instead", length, len(data)) - } + buf := uio.NewBigEndianBuffer(data) - // Offset from code + length byte var bootImages []BootImage - idx := 2 - for { - if idx >= length+2 { - break - } - image, err := BootImageFromBytes(data[idx:]) - if err != nil { - return nil, fmt.Errorf("parsing bytes stream: %v", err) + for buf.Has(5) { + var image BootImage + if err := (&image).Unmarshal(buf); err != nil { + return nil, err } - bootImages = append(bootImages, *image) - - // 4 bytes of BootImageID, 1 byte of name length, name - idx += 4 + 1 + len(image.Name) + bootImages = append(bootImages, image) } return &OptBootImageList{bootImages}, nil diff --git a/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go b/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go index d2784ae..0819d64 100644 --- a/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go +++ b/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go @@ -45,8 +45,6 @@ func TestOptBootImageListInterfaceMethods(t *testing.T) { func TestParseOptBootImageList(t *testing.T) { data := []byte{ - 9, // code - 22, // length // boot image 1 0x1, 0x0, 0x03, 0xe9, // ID 6, // name length @@ -78,25 +76,8 @@ func TestParseOptBootImageList(t *testing.T) { } require.Equal(t, &OptBootImageList{expectedBootImages}, o) - // Short byte stream - data = []byte{9} - _, err = ParseOptBootImageList(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptBootImageList(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{9, 10, 1, 1, 1} - _, err = ParseOptBootImageList(data) - require.Error(t, err, "should get error from bad length") - // Error parsing boot image (malformed) data = []byte{ - 9, // code - 22, // length // boot image 1 0x1, 0x0, 0x03, 0xe9, // ID 4, // name length @@ -108,25 +89,6 @@ func TestParseOptBootImageList(t *testing.T) { } _, err = ParseOptBootImageList(data) require.Error(t, err, "should get error from bad boot image") - - // Should not get error parsing boot image with excess length. - data = []byte{ - 9, // code - 22, // length - // boot image 1 - 0x1, 0x0, 0x03, 0xe9, // ID - 6, // name length - 'b', 's', 'd', 'p', '-', '1', - // boot image 2 - 0x80, 0x0, 0x23, 0x31, // ID - 6, // name length - 'b', 's', 'd', 'p', '-', '2', - - // Simulate another option after boot image list - 7, 4, 0x80, 0x0, 0x23, 0x32, - } - _, err = ParseOptBootImageList(data) - require.NoError(t, err, "should not get error from options after boot image list") } func TestOptBootImageListString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go b/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go index 4c87df8..52f7780 100644 --- a/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go +++ b/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// OptDefaultBootImageID contains the selected boot image ID. +// // Implements the BSDP option default boot image ID, which tells the client // which image is the default boot image if one is not selected. - -// OptDefaultBootImageID contains the selected boot image ID. type OptDefaultBootImageID struct { ID BootImageID } @@ -17,22 +18,12 @@ type OptDefaultBootImageID struct { // ParseOptDefaultBootImageID constructs an OptDefaultBootImageID struct from a sequence of // bytes and returns it, or an error. func ParseOptDefaultBootImageID(data []byte) (*OptDefaultBootImageID, error) { - if len(data) < 6 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionDefaultBootImageID { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionDefaultBootImageID, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("expected length 4, got %d instead", length) - } - id, err := BootImageIDFromBytes(data[2:6]) - if err != nil { + var o OptDefaultBootImageID + buf := uio.NewBigEndianBuffer(data) + if err := o.ID.Unmarshal(buf); err != nil { return nil, err } - return &OptDefaultBootImageID{*id}, nil + return &o, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go b/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go index e062e2d..ad29c30 100644 --- a/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go +++ b/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go @@ -17,24 +17,17 @@ func TestOptDefaultBootImageIDInterfaceMethods(t *testing.T) { func TestParseOptDefaultBootImageID(t *testing.T) { b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - bootImageBytes := b.ToBytes() - data := append([]byte{byte(OptionDefaultBootImageID), 4}, bootImageBytes...) - o, err := ParseOptDefaultBootImageID(data) + o, err := ParseOptDefaultBootImageID(b.ToBytes()) require.NoError(t, err) require.Equal(t, &OptDefaultBootImageID{b}, o) // Short byte stream - data = []byte{byte(OptionDefaultBootImageID), 4} + data := []byte{} _, err = ParseOptDefaultBootImageID(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 0, 0, 0} - _, err = ParseOptDefaultBootImageID(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionDefaultBootImageID), 5, 1, 0, 0, 0, 0} + data = []byte{1, 0, 0, 0, 0} _, err = ParseOptDefaultBootImageID(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/bsdp/bsdp_option_generic.go b/dhcpv4/bsdp/bsdp_option_generic.go index 6a51e29..6702e9c 100644 --- a/dhcpv4/bsdp/bsdp_option_generic.go +++ b/dhcpv4/bsdp/bsdp_option_generic.go @@ -16,21 +16,8 @@ type OptGeneric struct { // ParseOptGeneric parses a bytestream and creates a new OptGeneric from it, // or an error. -func ParseOptGeneric(data []byte) (*OptGeneric, error) { - if len(data) == 0 { - return nil, dhcpv4.ErrZeroLengthByteStream - } - var ( - length int - optionData []byte - ) - code := dhcpv4.OptionCode(data[0]) - length = int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("invalid data length: declared %v, actual %v", length, len(data)) - } - optionData = data[2 : length+2] - return &OptGeneric{OptionCode: code, Data: optionData}, nil +func ParseOptGeneric(code dhcpv4.OptionCode, data []byte) (*OptGeneric, error) { + return &OptGeneric{OptionCode: code, Data: data}, nil } // Code returns the generic option code. @@ -45,7 +32,7 @@ func (o OptGeneric) ToBytes() []byte { // String returns a human-readable representation of a generic option. func (o OptGeneric) String() string { - code, ok := OptionCodeToString[o.Code()] + code, ok := optionCodeToString[o.Code()] if !ok { code = "Unknown" } diff --git a/dhcpv4/bsdp/bsdp_option_generic_test.go b/dhcpv4/bsdp/bsdp_option_generic_test.go index 27436dd..eae77e1 100644 --- a/dhcpv4/bsdp/bsdp_option_generic_test.go +++ b/dhcpv4/bsdp/bsdp_option_generic_test.go @@ -7,19 +7,11 @@ import ( ) func TestParseOptGeneric(t *testing.T) { - // Empty bytestream produces error - _, err := ParseOptGeneric([]byte{}) - require.Error(t, err, "error from empty bytestream") - // Good parse - o, err := ParseOptGeneric([]byte{1, 1, 1}) + o, err := ParseOptGeneric(OptionMessageType, []byte{1}) require.NoError(t, err) require.Equal(t, OptionMessageType, o.Code()) require.Equal(t, MessageTypeList, MessageType(o.Data[0])) - - // Bad parse - o, err = ParseOptGeneric([]byte{1, 2, 1}) - require.Error(t, err, "invalid length") } func TestOptGenericCode(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_machine_name.go b/dhcpv4/bsdp/bsdp_option_machine_name.go index dc05378..cffba2e 100644 --- a/dhcpv4/bsdp/bsdp_option_machine_name.go +++ b/dhcpv4/bsdp/bsdp_option_machine_name.go @@ -1,15 +1,13 @@ package bsdp import ( - "fmt" - "github.com/insomniacslk/dhcp/dhcpv4" ) +// OptMachineName represents a BSDP message type. +// // Implements the BSDP option machine name, which gives the Netboot server's // machine name. - -// OptMachineName represents a BSDP message type. type OptMachineName struct { Name string } @@ -17,18 +15,7 @@ type OptMachineName struct { // ParseOptMachineName constructs an OptMachineName struct from a sequence of // bytes and returns it, or an error. func ParseOptMachineName(data []byte) (*OptMachineName, error) { - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionMachineName { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionMachineName, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("expected length %d, got %d instead", length, len(data)) - } - return &OptMachineName{Name: string(data[2 : length+2])}, nil + return &OptMachineName{Name: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_machine_name_test.go b/dhcpv4/bsdp/bsdp_option_machine_name_test.go index 9019020..712bc49 100644 --- a/dhcpv4/bsdp/bsdp_option_machine_name_test.go +++ b/dhcpv4/bsdp/bsdp_option_machine_name_test.go @@ -15,25 +15,10 @@ func TestOptMachineNameInterfaceMethods(t *testing.T) { } func TestParseOptMachineName(t *testing.T) { - data := []byte{130, 7, 's', 'o', 'm', 'e', 'b', 'o', 'x'} + data := []byte{'s', 'o', 'm', 'e', 'b', 'o', 'x'} o, err := ParseOptMachineName(data) require.NoError(t, err) require.Equal(t, &OptMachineName{"somebox"}, o) - - // Short byte stream - data = []byte{130} - _, err = ParseOptMachineName(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptMachineName(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{130, 5, 1} - _, err = ParseOptMachineName(data) - require.Error(t, err, "should get error from bad length") } func TestOptMachineNameString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_message_type.go b/dhcpv4/bsdp/bsdp_option_message_type.go index c11b56b..8c3c3d4 100644 --- a/dhcpv4/bsdp/bsdp_option_message_type.go +++ b/dhcpv4/bsdp/bsdp_option_message_type.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// MessageType represents the different BSDP message types. +// // Implements the BSDP option message type. Can be one of LIST, SELECT, or // FAILED. - -// MessageType represents the different BSDP message types. type MessageType byte // BSDP Message types - e.g. LIST, SELECT, FAILED @@ -20,14 +21,14 @@ const ( ) func (m MessageType) String() string { - if s, ok := MessageTypeToString[m]; ok { + if s, ok := messageTypeToString[m]; ok { return s } return "Unknown" } -// MessageTypeToString maps each BSDP message type to a human-readable string. -var MessageTypeToString = map[MessageType]string{ +// messageTypeToString maps each BSDP message type to a human-readable string. +var messageTypeToString = map[MessageType]string{ MessageTypeList: "LIST", MessageTypeSelect: "SELECT", MessageTypeFailed: "FAILED", @@ -41,18 +42,8 @@ type OptMessageType struct { // ParseOptMessageType constructs an OptMessageType struct from a sequence of // bytes and returns it, or an error. func ParseOptMessageType(data []byte) (*OptMessageType, error) { - if len(data) < 3 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionMessageType { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionMessageType, code) - } - length := int(data[1]) - if length != 1 { - return nil, fmt.Errorf("expected length 1, got %d instead", length) - } - return &OptMessageType{Type: MessageType(data[2])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptMessageType{Type: MessageType(buf.Read8())}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_message_type_test.go b/dhcpv4/bsdp/bsdp_option_message_type_test.go index 9b7564f..41652be 100644 --- a/dhcpv4/bsdp/bsdp_option_message_type_test.go +++ b/dhcpv4/bsdp/bsdp_option_message_type_test.go @@ -14,25 +14,10 @@ func TestOptMessageTypeInterfaceMethods(t *testing.T) { } func TestParseOptMessageType(t *testing.T) { - data := []byte{1, 1, 1} // DISCOVER + data := []byte{1} // DISCOVER o, err := ParseOptMessageType(data) require.NoError(t, err) require.Equal(t, &OptMessageType{MessageTypeList}, o) - - // Short byte stream - data = []byte{1, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{1, 5, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from bad length") } func TestOptMessageTypeString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_reply_port.go b/dhcpv4/bsdp/bsdp_option_reply_port.go index f1cc49f..da5e9c4 100644 --- a/dhcpv4/bsdp/bsdp_option_reply_port.go +++ b/dhcpv4/bsdp/bsdp_option_reply_port.go @@ -5,14 +5,15 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// OptReplyPort represents a BSDP protocol version. +// // Implements the BSDP option reply port. This is used when BSDP responses // should be sent to a reply port other than the DHCP default. The macOS GUI // "Startup Disk Select" sends this option since it's operating in an // unprivileged context. - -// OptReplyPort represents a BSDP protocol version. type OptReplyPort struct { Port uint16 } @@ -20,19 +21,8 @@ type OptReplyPort struct { // ParseOptReplyPort constructs an OptReplyPort struct from a sequence of // bytes and returns it, or an error. func ParseOptReplyPort(data []byte) (*OptReplyPort, error) { - if len(data) < 4 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionReplyPort { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionReplyPort, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("expected length 2, got %d instead", length) - } - port := binary.BigEndian.Uint16(data[2:4]) - return &OptReplyPort{port}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptReplyPort{buf.Read16()}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_reply_port_test.go b/dhcpv4/bsdp/bsdp_option_reply_port_test.go index c9906ff..719bbc8 100644 --- a/dhcpv4/bsdp/bsdp_option_reply_port_test.go +++ b/dhcpv4/bsdp/bsdp_option_reply_port_test.go @@ -14,23 +14,18 @@ func TestOptReplyPortInterfaceMethods(t *testing.T) { } func TestParseOptReplyPort(t *testing.T) { - data := []byte{byte(OptionReplyPort), 2, 0, 1} + data := []byte{0, 1} o, err := ParseOptReplyPort(data) require.NoError(t, err) require.Equal(t, &OptReplyPort{1}, o) // Short byte stream - data = []byte{byte(OptionReplyPort), 2} + data = []byte{} _, err = ParseOptReplyPort(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 0} - _, err = ParseOptReplyPort(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionReplyPort), 4, 1, 0} + data = []byte{1} _, err = ParseOptReplyPort(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go index 5b00ded..52b6eab 100644 --- a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go +++ b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// OptSelectedBootImageID contains the selected boot image ID. +// // Implements the BSDP option selected boot image ID, which tells the server // which boot image has been selected by the client. - -// OptSelectedBootImageID contains the selected boot image ID. type OptSelectedBootImageID struct { ID BootImageID } @@ -17,22 +18,12 @@ type OptSelectedBootImageID struct { // ParseOptSelectedBootImageID constructs an OptSelectedBootImageID struct from a sequence of // bytes and returns it, or an error. func ParseOptSelectedBootImageID(data []byte) (*OptSelectedBootImageID, error) { - if len(data) < 6 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionSelectedBootImageID { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionSelectedBootImageID, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("expected length 4, got %d instead", length) - } - id, err := BootImageIDFromBytes(data[2:6]) - if err != nil { + var o OptSelectedBootImageID + buf := uio.NewBigEndianBuffer(data) + if err := o.ID.Unmarshal(buf); err != nil { return nil, err } - return &OptSelectedBootImageID{*id}, nil + return &o, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go index 9529e41..a55fd9f 100644 --- a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go +++ b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go @@ -17,24 +17,18 @@ func TestOptSelectedBootImageIDInterfaceMethods(t *testing.T) { func TestParseOptSelectedBootImageID(t *testing.T) { b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - bootImageBytes := b.ToBytes() - data := append([]byte{byte(OptionSelectedBootImageID), 4}, bootImageBytes...) + data := b.ToBytes() o, err := ParseOptSelectedBootImageID(data) require.NoError(t, err) require.Equal(t, &OptSelectedBootImageID{b}, o) // Short byte stream - data = []byte{byte(OptionSelectedBootImageID), 4} + data = []byte{} _, err = ParseOptSelectedBootImageID(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 0, 0, 0} - _, err = ParseOptSelectedBootImageID(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionSelectedBootImageID), 5, 1, 0, 0, 0, 0} + data = []byte{1, 0, 0, 0, 0} _, err = ParseOptSelectedBootImageID(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/bsdp/bsdp_option_server_identifier.go b/dhcpv4/bsdp/bsdp_option_server_identifier.go index 252a0aa..26ec37a 100644 --- a/dhcpv4/bsdp/bsdp_option_server_identifier.go +++ b/dhcpv4/bsdp/bsdp_option_server_identifier.go @@ -5,6 +5,7 @@ import ( "net" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) // OptServerIdentifier represents an option encapsulating the server identifier. @@ -15,21 +16,8 @@ type OptServerIdentifier struct { // ParseOptServerIdentifier returns a new OptServerIdentifier from a byte // stream, or error if any. func ParseOptServerIdentifier(data []byte) (*OptServerIdentifier, error) { - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionServerIdentifier { - return nil, fmt.Errorf("expected code %v, got %v", OptionServerIdentifier, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexpected length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, dhcpv4.ErrShortByteStream - } - return &OptServerIdentifier{ServerID: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptServerIdentifier{ServerID: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_server_identifier_test.go b/dhcpv4/bsdp/bsdp_option_server_identifier_test.go index 5267caa..d832c40 100644 --- a/dhcpv4/bsdp/bsdp_option_server_identifier_test.go +++ b/dhcpv4/bsdp/bsdp_option_server_identifier_test.go @@ -26,15 +26,9 @@ func TestParseOptServerIdentifier(t *testing.T) { require.Error(t, err, "empty byte stream") o, err = ParseOptServerIdentifier([]byte{3, 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptServerIdentifier([]byte{3, 3, 192, 168, 0, 1}) require.Error(t, err, "wrong IP length") - o, err = ParseOptServerIdentifier([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptServerIdentifier([]byte{3, 4, 192, 168, 0, 1}) + o, err = ParseOptServerIdentifier([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.ServerID) } diff --git a/dhcpv4/bsdp/bsdp_option_server_priority.go b/dhcpv4/bsdp/bsdp_option_server_priority.go index 1952b7e..66bfa44 100644 --- a/dhcpv4/bsdp/bsdp_option_server_priority.go +++ b/dhcpv4/bsdp/bsdp_option_server_priority.go @@ -5,31 +5,19 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) -// This option implements the server identifier option -// https://tools.ietf.org/html/rfc2132 - // OptServerPriority represents an option encapsulating the server priority. type OptServerPriority struct { - Priority int + Priority uint16 } // ParseOptServerPriority returns a new OptServerPriority from a byte stream, or // error if any. func ParseOptServerPriority(data []byte) (*OptServerPriority, error) { - if len(data) < 4 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionServerPriority { - return nil, fmt.Errorf("expected code %v, got %v", OptionServerPriority, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("unexpected length: expected 2, got %v", length) - } - return &OptServerPriority{Priority: int(binary.BigEndian.Uint16(data[2:4]))}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptServerPriority{Priority: buf.Read16()}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_server_priority_test.go b/dhcpv4/bsdp/bsdp_option_server_priority_test.go index d12ad55..cbcef1d 100644 --- a/dhcpv4/bsdp/bsdp_option_server_priority_test.go +++ b/dhcpv4/bsdp/bsdp_option_server_priority_test.go @@ -22,16 +22,10 @@ func TestParseOptServerPriority(t *testing.T) { o, err = ParseOptServerPriority([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptServerPriority([]byte{4, 2, 1}) + o, err = ParseOptServerPriority([]byte{1}) require.Error(t, err, "short byte stream") - o, err = ParseOptServerPriority([]byte{4, 3, 1, 1}) - require.Error(t, err, "wrong priority length") - - o, err = ParseOptServerPriority([]byte{53, 2, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptServerPriority([]byte{4, 2, 0, 100}) + o, err = ParseOptServerPriority([]byte{0, 100}) require.NoError(t, err) - require.Equal(t, 100, o.Priority) + require.Equal(t, uint16(100), o.Priority) } diff --git a/dhcpv4/bsdp/bsdp_option_version.go b/dhcpv4/bsdp/bsdp_option_version.go index 8431a94..38158d7 100644 --- a/dhcpv4/bsdp/bsdp_option_version.go +++ b/dhcpv4/bsdp/bsdp_option_version.go @@ -4,10 +4,9 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) -// Implements the BSDP option version. Can be one of 1.0 or 1.1 - // Specific versions. var ( Version1_0 = []byte{1, 0} @@ -15,6 +14,8 @@ var ( ) // OptVersion represents a BSDP protocol version. +// +// Implements the BSDP option version. Can be one of 1.0 or 1.1 type OptVersion struct { Version []byte } @@ -22,18 +23,8 @@ type OptVersion struct { // ParseOptVersion constructs an OptVersion struct from a sequence of // bytes and returns it, or an error. func ParseOptVersion(data []byte) (*OptVersion, error) { - if len(data) < 4 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionVersion { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionVersion, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("expected length 2, got %d instead", length) - } - return &OptVersion{data[2:4]}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptVersion{buf.CopyN(2)}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_version_test.go b/dhcpv4/bsdp/bsdp_option_version_test.go index c6a6afc..d28f243 100644 --- a/dhcpv4/bsdp/bsdp_option_version_test.go +++ b/dhcpv4/bsdp/bsdp_option_version_test.go @@ -14,25 +14,15 @@ func TestOptVersionInterfaceMethods(t *testing.T) { } func TestParseOptVersion(t *testing.T) { - data := []byte{2, 2, 1, 1} + data := []byte{1, 1} o, err := ParseOptVersion(data) require.NoError(t, err) require.Equal(t, &OptVersion{Version1_1}, o) // Short byte stream - data = []byte{2, 2} + data = []byte{2} _, err = ParseOptVersion(data) require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 0} - _, err = ParseOptVersion(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{2, 4, 1, 0} - _, err = ParseOptVersion(data) - require.Error(t, err, "should get error from bad length") } func TestOptVersionString(t *testing.T) { diff --git a/dhcpv4/bsdp/option_vendor_specific_information.go b/dhcpv4/bsdp/option_vendor_specific_information.go index e735b57..1bd41a7 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information.go +++ b/dhcpv4/bsdp/option_vendor_specific_information.go @@ -1,8 +1,6 @@ package bsdp import ( - "errors" - "fmt" "strings" "github.com/insomniacslk/dhcp/dhcpv4" @@ -11,20 +9,17 @@ import ( // OptVendorSpecificInformation encapsulates the BSDP-specific options used for // the protocol. type OptVendorSpecificInformation struct { - Options []dhcpv4.Option + Options dhcpv4.Options } // parseOption is similar to dhcpv4.ParseOption, except that it switches based // on the BSDP specific options. -func parseOption(data []byte) (dhcpv4.Option, error) { - if len(data) == 0 { - return nil, dhcpv4.ErrZeroLengthByteStream - } +func parseOption(code dhcpv4.OptionCode, data []byte) (dhcpv4.Option, error) { var ( opt dhcpv4.Option err error ) - switch dhcpv4.OptionCode(data[0]) { + switch code { case OptionBootImageList: opt, err = ParseOptBootImageList(data) case OptionDefaultBootImageID: @@ -44,7 +39,7 @@ func parseOption(data []byte) (dhcpv4.Option, error) { case OptionVersion: opt, err = ParseOptVersion(data) default: - opt, err = ParseOptGeneric(data) + opt, err = ParseOptGeneric(code, data) } if err != nil { return nil, err @@ -55,39 +50,10 @@ func parseOption(data []byte) (dhcpv4.Option, error) { // ParseOptVendorSpecificInformation constructs an OptVendorSpecificInformation struct from a sequence of // bytes and returns it, or an error. func ParseOptVendorSpecificInformation(data []byte) (*OptVendorSpecificInformation, error) { - // Should at least have code + length - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != dhcpv4.OptionVendorSpecificInformation { - return nil, fmt.Errorf("expected option %v, got %v instead", dhcpv4.OptionVendorSpecificInformation, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("expected length 2, got %d instead", length) - } - - options := make([]dhcpv4.Option, 0, 10) - idx := 2 - for { - if idx == len(data) { - break - } - // This should never happen. - if idx > len(data) { - return nil, errors.New("read past the end of options") - } - opt, err := parseOption(data[idx:]) - if err != nil { - return nil, err - } - options = append(options, opt) - - // Account for code + length bytes - idx += 2 + opt.Length() + options, err := dhcpv4.OptionsFromBytesWithParser(data, parseOption, false /* don't check for OptionEnd tag */) + if err != nil { + return nil, err } - return &OptVendorSpecificInformation{options}, nil } @@ -133,20 +99,10 @@ func (o *OptVendorSpecificInformation) Length() int { // GetOption returns all suboptions that match the given OptionCode code. func (o *OptVendorSpecificInformation) GetOption(code dhcpv4.OptionCode) []dhcpv4.Option { - var opts []dhcpv4.Option - for _, opt := range o.Options { - if opt.Code() == code { - opts = append(opts, opt) - } - } - return opts + return o.Options.GetOption(code) } // GetOneOption returns the first suboption that matches the OptionCode code. func (o *OptVendorSpecificInformation) GetOneOption(code dhcpv4.OptionCode) dhcpv4.Option { - opts := o.GetOption(code) - if len(opts) == 0 { - return nil - } - return opts[0] + return o.Options.GetOneOption(code) } diff --git a/dhcpv4/bsdp/option_vendor_specific_information_test.go b/dhcpv4/bsdp/option_vendor_specific_information_test.go index 8a4368f..8aa2bdf 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information_test.go +++ b/dhcpv4/bsdp/option_vendor_specific_information_test.go @@ -34,19 +34,11 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { o *OptVendorSpecificInformation err error ) - o, err = ParseOptVendorSpecificInformation([]byte{}) - require.Error(t, err, "empty byte stream") - o, err = ParseOptVendorSpecificInformation([]byte{1, 2}) require.Error(t, err, "short byte stream") - o, err = ParseOptVendorSpecificInformation([]byte{53, 2, 1, 1}) - require.Error(t, err, "wrong option code") - // Good byte stream data := []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, 1, // Version option } @@ -59,13 +51,13 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { }, } require.Equal(t, 2, len(o.Options), "number of parsed suboptions") - require.Equal(t, expected.Options[0].Code(), o.Options[0].Code()) - require.Equal(t, expected.Options[1].Code(), o.Options[1].Code()) + typ := o.GetOneOption(OptionMessageType) + version := o.GetOneOption(OptionVersion) + require.Equal(t, expected.Options[0].Code(), typ.Code()) + require.Equal(t, expected.Options[1].Code(), version.Code()) // Short byte stream (length and data mismatch) data = []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, // Version option } @@ -74,8 +66,6 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { // Bad option data = []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, // Version option 5, 3, 1, 1, 1, // Reply port option @@ -85,8 +75,6 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { // Boot images + default. data = []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, 1, // Version option 5, 2, 1, 1, // Reply port option diff --git a/dhcpv4/bsdp/types.go b/dhcpv4/bsdp/types.go index ac4b05b..aa9a824 100644 --- a/dhcpv4/bsdp/types.go +++ b/dhcpv4/bsdp/types.go @@ -26,9 +26,9 @@ const ( OptionMachineName dhcpv4.OptionCode = 130 ) -// OptionCodeToString maps BSDP OptionCodes to human-readable strings +// optionCodeToString maps BSDP OptionCodes to human-readable strings // describing what they are. -var OptionCodeToString = map[dhcpv4.OptionCode]string{ +var optionCodeToString = map[dhcpv4.OptionCode]string{ OptionMessageType: "BSDP Message Type", OptionVersion: "BSDP Version", OptionServerIdentifier: "BSDP Server Identifier", diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index b24502f..2f832e8 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -46,7 +46,7 @@ type DHCPv4 struct { ClientHWAddr net.HardwareAddr ServerHostName string BootFileName string - Options []Option + Options Options } // Modifier defines the signature for functions that can modify DHCPv4 @@ -363,9 +363,32 @@ func (d *DHCPv4) SetUnicast() { // GetOption will attempt to get all options that match a DHCPv4 option // from its OptionCode. If the option was not found it will return an // empty list. +// +// According to RFC 3396, options that are specified more than once are +// concatenated, and hence this should always just return one option. func (d *DHCPv4) GetOption(code OptionCode) []Option { + return d.Options.GetOption(code) +} + +// GetOneOption will attempt to get an option that match a Option code. +// If there are multiple options with the same OptionCode it will only return +// the first one found. If no matching option is found nil will be returned. +func (d *DHCPv4) GetOneOption(code OptionCode) Option { + return d.Options.GetOneOption(code) +} + +// Options is a collection of options. +type Options []Option + +// GetOption will attempt to get all options that match a DHCPv4 option +// from its OptionCode. If the option was not found it will return an +// empty list. +// +// According to RFC 3396, options that are specified more than once are +// concatenated, and hence this should always just return one option. +func (o Options) GetOption(code OptionCode) []Option { opts := []Option{} - for _, opt := range d.Options { + for _, opt := range o { if opt.Code() == code { opts = append(opts, opt) } @@ -376,8 +399,8 @@ func (d *DHCPv4) GetOption(code OptionCode) []Option { // GetOneOption will attempt to get an option that match a Option code. // If there are multiple options with the same OptionCode it will only return // the first one found. If no matching option is found nil will be returned. -func (d *DHCPv4) GetOneOption(code OptionCode) Option { - for _, opt := range d.Options { +func (o Options) GetOneOption(code OptionCode) Option { + for _, opt := range o { if opt.Code() == code { return opt } @@ -385,21 +408,6 @@ func (d *DHCPv4) GetOneOption(code OptionCode) Option { return nil } -// StrippedOptions works like Options, but it does not return anything after the -// End option. -func (d *DHCPv4) StrippedOptions() []Option { - // differently from Options() this function strips away anything coming - // after the End option (normally just Pad options). - strippedOptions := []Option{} - for _, opt := range d.Options { - strippedOptions = append(strippedOptions, opt) - if opt.Code() == OptionEnd { - break - } - } - return strippedOptions -} - // AddOption appends an option to the existing ones. If the last option is an // OptionEnd, it will be inserted before that. It does not deal with End // options that appead before the end, like in malformed packets. diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index f6f9d7c..9be2a1a 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -44,6 +44,7 @@ func TestFromBytes(t *testing.T) { 0, 0, 0, 0, // gateway IP address 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // client MAC address + padding } + // server host name expectedHostname := []byte{} for i := 0; i < 64; i++ { @@ -57,7 +58,7 @@ func TestFromBytes(t *testing.T) { } data = append(data, expectedBootfilename...) // magic cookie, then no options - data = append(data, []byte{99, 130, 83, 99}...) + data = append(data, magicCookie[:]...) d, err := FromBytes(data) require.NoError(t, err) @@ -179,11 +180,11 @@ func TestGetOption(t *testing.T) { } hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")} - bootFileOpt1 := &OptBootfileName{[]byte("boot.img")} - bootFileOpt2 := &OptBootfileName{[]byte("boot2.img")} + bootFileOpt1 := &OptBootfileName{"boot.img"} + bootFileOpt2 := &OptBootfileName{"boot2.img"} d.AddOption(hostnameOpt) - d.AddOption(&OptBootfileName{[]byte("boot.img")}) - d.AddOption(&OptBootfileName{[]byte("boot2.img")}) + d.AddOption(&OptBootfileName{"boot.img"}) + d.AddOption(&OptBootfileName{"boot2.img"}) require.Equal(t, d.GetOption(OptionHostName), []Option{hostnameOpt}) require.Equal(t, d.GetOption(OptionBootfileName), []Option{bootFileOpt1, bootFileOpt2}) @@ -227,32 +228,6 @@ func TestUpdateOption(t *testing.T) { require.Equal(t, OptionEnd, d.Options[1].Code()) } -func TestStrippedOptions(t *testing.T) { - // Normal set of options that terminate with OptionEnd. - d, err := New() - require.NoError(t, err) - opts := []Option{ - &OptBootfileName{[]byte("boot.img")}, - &OptClassIdentifier{"something"}, - &OptionGeneric{OptionCode: OptionEnd}, - } - d.Options = opts - stripped := d.StrippedOptions() - require.Equal(t, len(opts), len(stripped)) - for i := range stripped { - require.Equal(t, opts[i], stripped[i]) - } - - // Set of options with additional options after OptionEnd - opts = append(opts, &OptMaximumDHCPMessageSize{uint16(1234)}) - d.Options = opts - stripped = d.StrippedOptions() - require.Equal(t, len(opts)-1, len(stripped)) - for i := range stripped { - require.Equal(t, opts[i], stripped[i]) - } -} - func TestDHCPv4NewRequestFromOffer(t *testing.T) { offer, err := New() require.NoError(t, err) diff --git a/dhcpv4/option_archtype.go b/dhcpv4/option_archtype.go index 92a3769..f5882bc 100644 --- a/dhcpv4/option_archtype.go +++ b/dhcpv4/option_archtype.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/insomniacslk/dhcp/iana" + "github.com/u-root/u-root/pkg/uio" ) // OptClientArchType represents an option encapsulating the Client System @@ -35,7 +36,7 @@ func (o *OptClientArchType) ToBytes() []byte { // Length returns the length of the data portion (excluding option code an byte // length). func (o *OptClientArchType) Length() int { - return 2*len(o.ArchTypes) + return 2 * len(o.ArchTypes) } // String returns a human-readable string. @@ -53,24 +54,14 @@ func (o *OptClientArchType) String() string { // ParseOptClientArchType returns a new OptClientArchType from a byte stream, // or error if any. func ParseOptClientArchType(data []byte) (*OptClientArchType, error) { - if len(data) < 2 { - return nil, ErrShortByteStream + buf := uio.NewBigEndianBuffer(data) + if buf.Len() == 0 { + return nil, fmt.Errorf("must have at least one archtype if option is present") } - code := OptionCode(data[0]) - if code != OptionClientSystemArchitectureType { - return nil, fmt.Errorf("expected code %v, got %v", OptionClientSystemArchitectureType, code) - } - length := int(data[1]) - if length == 0 || length%2 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 2 larger than 2, got %v", length) - } - if len(data) < 2+length { - return nil, ErrShortByteStream - } - archTypes := make([]iana.ArchType, 0, length%2) - for idx := 0; idx < length; idx += 2 { - b := data[2+idx : 2+idx+2] - archTypes = append(archTypes, iana.ArchType(binary.BigEndian.Uint16(b))) + + archTypes := make([]iana.ArchType, 0, buf.Len()/2) + for buf.Has(2) { + archTypes = append(archTypes, iana.ArchType(buf.Read16())) } - return &OptClientArchType{ArchTypes: archTypes}, nil + return &OptClientArchType{ArchTypes: archTypes}, buf.FinError() } diff --git a/dhcpv4/option_archtype_test.go b/dhcpv4/option_archtype_test.go index d803328..482ebb1 100644 --- a/dhcpv4/option_archtype_test.go +++ b/dhcpv4/option_archtype_test.go @@ -13,7 +13,7 @@ func TestParseOptClientArchType(t *testing.T) { 2, // Length 0, 6, // EFI_IA32 } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ArchTypes[0], iana.EFI_IA32) } @@ -25,7 +25,7 @@ func TestParseOptClientArchTypeMultiple(t *testing.T) { 0, 6, // EFI_IA32 0, 2, // EFI_ITANIUM } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ArchTypes[0], iana.EFI_IA32) require.Equal(t, opt.ArchTypes[1], iana.EFI_ITANIUM) @@ -43,7 +43,7 @@ func TestOptClientArchTypeParseAndToBytes(t *testing.T) { 2, // Length 0, 8, // EFI_XSCALE } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ToBytes(), data) } @@ -55,7 +55,7 @@ func TestOptClientArchTypeParseAndToBytesMultiple(t *testing.T) { 0, 8, // EFI_XSCALE 0, 6, // EFI_IA32 } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ToBytes(), data) } diff --git a/dhcpv4/option_bootfile_name.go b/dhcpv4/option_bootfile_name.go index ca9317b..e06e4eb 100644 --- a/dhcpv4/option_bootfile_name.go +++ b/dhcpv4/option_bootfile_name.go @@ -9,7 +9,7 @@ import ( // OptBootfileName implements the BootFile Name option type OptBootfileName struct { - BootfileName []byte + BootfileName string } // Code returns the option code @@ -19,7 +19,7 @@ func (op *OptBootfileName) Code() OptionCode { // ToBytes serializes the option and returns it as a sequence of bytes func (op *OptBootfileName) ToBytes() []byte { - return append([]byte{byte(op.Code()), byte(op.Length())}, op.BootfileName...) + return append([]byte{byte(op.Code()), byte(op.Length())}, []byte(op.BootfileName)...) } // Length returns the option length in bytes @@ -34,21 +34,5 @@ func (op *OptBootfileName) String() string { // ParseOptBootfileName returns a new OptBootfile from a byte stream or error if any func ParseOptBootfileName(data []byte) (*OptBootfileName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionBootfileName { - return nil, fmt.Errorf("ParseOptBootfileName: invalid code: %v; want %v", code, OptionBootfileName) - } - length := int(data[1]) - if length < 1 { - return nil, fmt.Errorf("Bootfile name has invalid length of %d", length) - } - bootFileNameData := data[2:] - if len(bootFileNameData) < length { - return nil, fmt.Errorf("ParseOptBootfileName: short data: %d bytes; want %d", - len(bootFileNameData), length) - } - return &OptBootfileName{BootfileName: bootFileNameData[:length]}, nil + return &OptBootfileName{BootfileName: string(data)}, nil } diff --git a/dhcpv4/option_bootfile_name_test.go b/dhcpv4/option_bootfile_name_test.go index 0c7c200..2671ac5 100644 --- a/dhcpv4/option_bootfile_name_test.go +++ b/dhcpv4/option_bootfile_name_test.go @@ -13,7 +13,7 @@ func TestOptBootfileNameCode(t *testing.T) { func TestOptBootfileNameToBytes(t *testing.T) { opt := OptBootfileName{ - BootfileName: []byte("linuxboot"), + BootfileName: "linuxboot", } data := opt.ToBytes() expected := []byte{ @@ -26,40 +26,15 @@ func TestOptBootfileNameToBytes(t *testing.T) { func TestParseOptBootfileName(t *testing.T) { expected := []byte{ - 67, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptBootfileName(expected) require.NoError(t, err) require.Equal(t, 9, opt.Length()) - require.Equal(t, "linuxboot", string(opt.BootfileName)) -} - -func TestParseOptBootfileNameZeroLength(t *testing.T) { - expected := []byte{ - 67, 0, - } - _, err := ParseOptBootfileName(expected) - require.Error(t, err) -} - -func TestParseOptBootfileNameInvalidLength(t *testing.T) { - expected := []byte{ - 67, 9, 'l', 'i', 'n', 'u', 'x', 'b', - } - _, err := ParseOptBootfileName(expected) - require.Error(t, err) -} - -func TestParseOptBootfileNameShortLength(t *testing.T) { - expected := []byte{ - 67, 4, 'l', 'i', 'n', 'u', 'x', - } - opt, err := ParseOptBootfileName(expected) - require.NoError(t, err) - require.Equal(t, []byte("linu"), opt.BootfileName) + require.Equal(t, "linuxboot", opt.BootfileName) } func TestOptBootfileNameString(t *testing.T) { - o := OptBootfileName{BootfileName: []byte("testy test")} + o := OptBootfileName{BootfileName: "testy test"} require.Equal(t, "Bootfile Name -> testy test", o.String()) } diff --git a/dhcpv4/option_broadcast_address.go b/dhcpv4/option_broadcast_address.go index ffa57e3..fdce946 100644 --- a/dhcpv4/option_broadcast_address.go +++ b/dhcpv4/option_broadcast_address.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the server identifier option @@ -16,21 +18,8 @@ type OptBroadcastAddress struct { // ParseOptBroadcastAddress returns a new OptBroadcastAddress from a byte // stream, or error if any. func ParseOptBroadcastAddress(data []byte) (*OptBroadcastAddress, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionBroadcastAddress { - return nil, fmt.Errorf("expected code %v, got %v", OptionBroadcastAddress, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptBroadcastAddress{BroadcastAddress: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptBroadcastAddress{BroadcastAddress: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_broadcast_address_test.go b/dhcpv4/option_broadcast_address_test.go index 3572dc0..1feb6cc 100644 --- a/dhcpv4/option_broadcast_address_test.go +++ b/dhcpv4/option_broadcast_address_test.go @@ -29,16 +29,10 @@ func TestParseOptBroadcastAddress(t *testing.T) { o, err = ParseOptBroadcastAddress([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptBroadcastAddress([]byte{byte(OptionBroadcastAddress), 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptBroadcastAddress([]byte{byte(OptionBroadcastAddress), 3, 192, 168, 0, 1}) + o, err = ParseOptBroadcastAddress([]byte{192, 168, 0}) require.Error(t, err, "wrong IP length") - o, err = ParseOptBroadcastAddress([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptBroadcastAddress([]byte{byte(OptionBroadcastAddress), 4, 192, 168, 0, 1}) + o, err = ParseOptBroadcastAddress([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.BroadcastAddress) } diff --git a/dhcpv4/option_class_identifier.go b/dhcpv4/option_class_identifier.go index ae5dba2..1a49b87 100644 --- a/dhcpv4/option_class_identifier.go +++ b/dhcpv4/option_class_identifier.go @@ -15,19 +15,7 @@ type OptClassIdentifier struct { // ParseOptClassIdentifier constructs an OptClassIdentifier struct from a sequence of // bytes and returns it, or an error. func ParseOptClassIdentifier(data []byte) (*OptClassIdentifier, error) { - // Should at least have code and length - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionClassIdentifier { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionClassIdentifier, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptClassIdentifier{Identifier: string(data[2 : 2+length])}, nil + return &OptClassIdentifier{Identifier: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_class_identifier_test.go b/dhcpv4/option_class_identifier_test.go index 786ecfb..289eafe 100644 --- a/dhcpv4/option_class_identifier_test.go +++ b/dhcpv4/option_class_identifier_test.go @@ -14,25 +14,10 @@ func TestOptClassIdentifierInterfaceMethods(t *testing.T) { } func TestParseOptClassIdentifier(t *testing.T) { - data := []byte{byte(OptionClassIdentifier), 4, 't', 'e', 's', 't'} // DISCOVER + data := []byte{'t', 'e', 's', 't'} o, err := ParseOptClassIdentifier(data) require.NoError(t, err) require.Equal(t, &OptClassIdentifier{Identifier: "test"}, o) - - // Short byte stream - data = []byte{byte(OptionClassIdentifier)} - _, err = ParseOptClassIdentifier(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptClassIdentifier(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionClassIdentifier), 6, 1, 1, 1} - _, err = ParseOptClassIdentifier(data) - require.Error(t, err, "should get error from bad length") } func TestOptClassIdentifierString(t *testing.T) { diff --git a/dhcpv4/option_domain_name.go b/dhcpv4/option_domain_name.go index e876b05..673b2a6 100644 --- a/dhcpv4/option_domain_name.go +++ b/dhcpv4/option_domain_name.go @@ -13,18 +13,7 @@ type OptDomainName struct { // ParseOptDomainName returns a new OptDomainName from a byte // stream, or error if any. func ParseOptDomainName(data []byte) (*OptDomainName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionDomainName { - return nil, fmt.Errorf("expected code %v, got %v", OptionDomainName, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptDomainName{DomainName: string(data[2 : 2+length])}, nil + return &OptDomainName{DomainName: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_domain_name_server.go b/dhcpv4/option_domain_name_server.go index 470eaa0..8633cc4 100644 --- a/dhcpv4/option_domain_name_server.go +++ b/dhcpv4/option_domain_name_server.go @@ -17,26 +17,11 @@ type OptDomainNameServer struct { // ParseOptDomainNameServer returns a new OptDomainNameServer from a byte // stream, or error if any. func ParseOptDomainNameServer(data []byte) (*OptDomainNameServer, error) { - if len(data) < 2 { - return nil, ErrShortByteStream + ips, err := ParseIPs(data) + if err != nil { + return nil, err } - code := OptionCode(data[0]) - if code != OptionDomainNameServer { - return nil, fmt.Errorf("expected code %v, got %v", OptionDomainNameServer, code) - } - length := int(data[1]) - if length == 0 || length%4 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 4 larger than 4, got %v", length) - } - if len(data) < 2+length { - return nil, ErrShortByteStream - } - nameservers := make([]net.IP, 0, length%4) - for idx := 0; idx < length; idx += 4 { - b := data[2+idx : 2+idx+4] - nameservers = append(nameservers, net.IPv4(b[0], b[1], b[2], b[3])) - } - return &OptDomainNameServer{NameServers: nameservers}, nil + return &OptDomainNameServer{NameServers: ips}, nil } // Code returns the option code. diff --git a/dhcpv4/option_domain_name_server_test.go b/dhcpv4/option_domain_name_server_test.go index c801cb6..546d399 100644 --- a/dhcpv4/option_domain_name_server_test.go +++ b/dhcpv4/option_domain_name_server_test.go @@ -25,37 +25,23 @@ func TestParseOptDomainNameServer(t *testing.T) { 192, 168, 0, 10, // DNS #1 192, 168, 0, 20, // DNS #2 } - o, err := ParseOptDomainNameServer(data) + o, err := ParseOptDomainNameServer(data[2:]) require.NoError(t, err) servers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), + net.IP{192, 168, 0, 10}, + net.IP{192, 168, 0, 20}, } require.Equal(t, &OptDomainNameServer{NameServers: servers}, o) - // Short byte stream - data = []byte{byte(OptionDomainNameServer)} - _, err = ParseOptDomainNameServer(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptDomainNameServer(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionDomainNameServer), 6, 1, 1, 1} + data = []byte{1, 1, 1} _, err = ParseOptDomainNameServer(data) require.Error(t, err, "should get error from bad length") } func TestParseOptDomainNameServerNoServers(t *testing.T) { // RFC2132 requires that at least one DNS server IP is specified - data := []byte{ - byte(OptionDomainNameServer), - 0, // Length - } - _, err := ParseOptDomainNameServer(data) + _, err := ParseOptDomainNameServer([]byte{}) require.Error(t, err) } diff --git a/dhcpv4/option_domain_name_test.go b/dhcpv4/option_domain_name_test.go index ab66e8b..e88d87e 100644 --- a/dhcpv4/option_domain_name_test.go +++ b/dhcpv4/option_domain_name_test.go @@ -14,25 +14,10 @@ func TestOptDomainNameInterfaceMethods(t *testing.T) { } func TestParseOptDomainName(t *testing.T) { - data := []byte{byte(OptionDomainName), 4, 't', 'e', 's', 't'} // DISCOVER + data := []byte{'t', 'e', 's', 't'} o, err := ParseOptDomainName(data) require.NoError(t, err) require.Equal(t, &OptDomainName{DomainName: "test"}, o) - - // Short byte stream - data = []byte{byte(OptionDomainName)} - _, err = ParseOptDomainName(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptDomainName(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionDomainName), 6, 1, 1, 1} - _, err = ParseOptDomainName(data) - require.Error(t, err, "should get error from bad length") } func TestOptDomainNameString(t *testing.T) { diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go index 9c24eea..5fafb6e 100644 --- a/dhcpv4/option_domain_search.go +++ b/dhcpv4/option_domain_search.go @@ -43,18 +43,7 @@ func (op *OptDomainSearch) String() string { // ParseOptDomainSearch returns a new OptDomainSearch from a byte stream, or // error if any. func ParseOptDomainSearch(data []byte) (*OptDomainSearch, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionDNSDomainSearchList { - return nil, fmt.Errorf("expected code %v, got %v", OptionDNSDomainSearchList, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - labels, err := rfc1035label.FromBytes(data[2 : length+2]) + labels, err := rfc1035label.FromBytes(data) if err != nil { return nil, err } diff --git a/dhcpv4/option_domain_search_test.go b/dhcpv4/option_domain_search_test.go index 590ccd0..0e5f8e9 100644 --- a/dhcpv4/option_domain_search_test.go +++ b/dhcpv4/option_domain_search_test.go @@ -14,7 +14,7 @@ func TestParseOptDomainSearch(t *testing.T) { 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } - opt, err := ParseOptDomainSearch(data) + opt, err := ParseOptDomainSearch(data[2:]) require.NoError(t, err) require.Equal(t, 2, len(opt.DomainSearch.Labels)) require.Equal(t, data[2:], opt.DomainSearch.ToBytes()) diff --git a/dhcpv4/option_generic.go b/dhcpv4/option_generic.go index 4ff35f8..964a655 100644 --- a/dhcpv4/option_generic.go +++ b/dhcpv4/option_generic.go @@ -15,23 +15,11 @@ type OptionGeneric struct { // ParseOptionGeneric parses a bytestream and creates a new OptionGeneric from // it, or an error. -func ParseOptionGeneric(data []byte) (*OptionGeneric, error) { +func ParseOptionGeneric(code OptionCode, data []byte) (Option, error) { if len(data) == 0 { return nil, errors.New("invalid zero-length bytestream") } - var ( - length int - optionData []byte - ) - code := OptionCode(data[0]) - if code != OptionPad && code != OptionEnd { - length = int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("invalid data length: declared %v, actual %v", length, len(data)) - } - optionData = data[2 : length+2] - } - return &OptionGeneric{OptionCode: code, Data: optionData}, nil + return &OptionGeneric{OptionCode: code, Data: data}, nil } // Code returns the generic option code. diff --git a/dhcpv4/option_generic_test.go b/dhcpv4/option_generic_test.go index dbc0fc1..5c34903 100644 --- a/dhcpv4/option_generic_test.go +++ b/dhcpv4/option_generic_test.go @@ -8,7 +8,7 @@ import ( func TestParseOptionGeneric(t *testing.T) { // Empty bytestream produces error - _, err := ParseOptionGeneric([]byte{}) + _, err := ParseOptionGeneric(OptionHostName, []byte{}) require.Error(t, err, "error from empty bytestream") } @@ -20,14 +20,6 @@ func TestOptionGenericCode(t *testing.T) { require.Equal(t, OptionDHCPMessageType, o.Code()) } -func TestOptionGenericData(t *testing.T) { - o := OptionGeneric{ - OptionCode: OptionNameServer, - Data: []byte{192, 168, 0, 1}, - } - require.Equal(t, []byte{192, 168, 0, 1}, o.Data) -} - func TestOptionGenericToBytes(t *testing.T) { o := OptionGeneric{ OptionCode: OptionDHCPMessageType, diff --git a/dhcpv4/option_host_name.go b/dhcpv4/option_host_name.go index a922a2b..decc18b 100644 --- a/dhcpv4/option_host_name.go +++ b/dhcpv4/option_host_name.go @@ -13,18 +13,7 @@ type OptHostName struct { // ParseOptHostName returns a new OptHostName from a byte stream, or error if // any. func ParseOptHostName(data []byte) (*OptHostName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionHostName { - return nil, fmt.Errorf("expected code %v, got %v", OptionHostName, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptHostName{HostName: string(data[2 : 2+length])}, nil + return &OptHostName{HostName: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_host_name_test.go b/dhcpv4/option_host_name_test.go index 7f99100..f5e8548 100644 --- a/dhcpv4/option_host_name_test.go +++ b/dhcpv4/option_host_name_test.go @@ -14,25 +14,10 @@ func TestOptHostNameInterfaceMethods(t *testing.T) { } func TestParseOptHostName(t *testing.T) { - data := []byte{byte(OptionHostName), 4, 't', 'e', 's', 't'} + data := []byte{'t', 'e', 's', 't'} o, err := ParseOptHostName(data) require.NoError(t, err) require.Equal(t, &OptHostName{HostName: "test"}, o) - - // Short byte stream - data = []byte{byte(OptionHostName)} - _, err = ParseOptHostName(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptHostName(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionHostName), 6, 1, 1, 1} - _, err = ParseOptHostName(data) - require.Error(t, err, "should get error from bad length") } func TestOptHostNameString(t *testing.T) { diff --git a/dhcpv4/option_ip_address_lease_time.go b/dhcpv4/option_ip_address_lease_time.go index 7562c58..2f63fb7 100644 --- a/dhcpv4/option_ip_address_lease_time.go +++ b/dhcpv4/option_ip_address_lease_time.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the IP Address Lease Time option @@ -16,20 +18,9 @@ type OptIPAddressLeaseTime struct { // ParseOptIPAddressLeaseTime constructs an OptIPAddressLeaseTime struct from a // sequence of bytes and returns it, or an error. func ParseOptIPAddressLeaseTime(data []byte) (*OptIPAddressLeaseTime, error) { - // Should at least have code, length, and lease time. - if len(data) < 6 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionIPAddressLeaseTime { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionIPAddressLeaseTime, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("expected length 4, got %v instead", length) - } - leaseTime := binary.BigEndian.Uint32(data[2:6]) - return &OptIPAddressLeaseTime{LeaseTime: leaseTime}, nil + buf := uio.NewBigEndianBuffer(data) + leaseTime := buf.Read32() + return &OptIPAddressLeaseTime{LeaseTime: leaseTime}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_ip_address_lease_time_test.go b/dhcpv4/option_ip_address_lease_time_test.go index 7d507bf..dafa6e4 100644 --- a/dhcpv4/option_ip_address_lease_time_test.go +++ b/dhcpv4/option_ip_address_lease_time_test.go @@ -14,23 +14,18 @@ func TestOptIPAddressLeaseTimeInterfaceMethods(t *testing.T) { } func TestParseOptIPAddressLeaseTime(t *testing.T) { - data := []byte{51, 4, 0, 0, 168, 192} + data := []byte{0, 0, 168, 192} o, err := ParseOptIPAddressLeaseTime(data) require.NoError(t, err) require.Equal(t, &OptIPAddressLeaseTime{LeaseTime: 43200}, o) // Short byte stream - data = []byte{51, 4, 168, 192} + data = []byte{168, 192} _, err = ParseOptIPAddressLeaseTime(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 4, 0, 0, 168, 192} - _, err = ParseOptIPAddressLeaseTime(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{51, 5, 1, 1, 1, 1, 1} + data = []byte{1, 1, 1, 1, 1} _, err = ParseOptIPAddressLeaseTime(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/option_maximum_dhcp_message_size.go b/dhcpv4/option_maximum_dhcp_message_size.go index e5fedc6..26865fd 100644 --- a/dhcpv4/option_maximum_dhcp_message_size.go +++ b/dhcpv4/option_maximum_dhcp_message_size.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the Maximum DHCP Message size option @@ -16,20 +18,8 @@ type OptMaximumDHCPMessageSize struct { // ParseOptMaximumDHCPMessageSize constructs an OptMaximumDHCPMessageSize struct from a sequence of // bytes and returns it, or an error. func ParseOptMaximumDHCPMessageSize(data []byte) (*OptMaximumDHCPMessageSize, error) { - // Should at least have code, length, and message size. - if len(data) < 4 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionMaximumDHCPMessageSize { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionMaximumDHCPMessageSize, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("expected length 2, got %v instead", length) - } - msgSize := binary.BigEndian.Uint16(data[2:4]) - return &OptMaximumDHCPMessageSize{Size: msgSize}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptMaximumDHCPMessageSize{Size: buf.Read16()}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_maximum_dhcp_message_size_test.go b/dhcpv4/option_maximum_dhcp_message_size_test.go index f24b499..24ba49f 100644 --- a/dhcpv4/option_maximum_dhcp_message_size_test.go +++ b/dhcpv4/option_maximum_dhcp_message_size_test.go @@ -14,23 +14,18 @@ func TestOptMaximumDHCPMessageSizeInterfaceMethods(t *testing.T) { } func TestParseOptMaximumDHCPMessageSize(t *testing.T) { - data := []byte{57, 2, 5, 220} + data := []byte{5, 220} o, err := ParseOptMaximumDHCPMessageSize(data) require.NoError(t, err) require.Equal(t, &OptMaximumDHCPMessageSize{Size: 1500}, o) // Short byte stream - data = []byte{57, 2} + data = []byte{2} _, err = ParseOptMaximumDHCPMessageSize(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptMaximumDHCPMessageSize(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{57, 3, 1, 1, 1} + data = []byte{1, 1, 1} _, err = ParseOptMaximumDHCPMessageSize(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/option_message_type.go b/dhcpv4/option_message_type.go index 903a57e..3e11a9a 100644 --- a/dhcpv4/option_message_type.go +++ b/dhcpv4/option_message_type.go @@ -2,6 +2,8 @@ package dhcpv4 import ( "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the message type option @@ -15,20 +17,8 @@ type OptMessageType struct { // ParseOptMessageType constructs an OptMessageType struct from a sequence of // bytes and returns it, or an error. func ParseOptMessageType(data []byte) (*OptMessageType, error) { - // Should at least have code, length, and message type. - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionDHCPMessageType { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionDHCPMessageType, code) - } - length := int(data[1]) - if length != 1 { - return nil, ErrShortByteStream - } - messageType := MessageType(data[2]) - return &OptMessageType{MessageType: messageType}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptMessageType{MessageType: MessageType(buf.Read8())}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_message_type_test.go b/dhcpv4/option_message_type_test.go index 59e6c17..091066e 100644 --- a/dhcpv4/option_message_type_test.go +++ b/dhcpv4/option_message_type_test.go @@ -20,23 +20,13 @@ func TestOptMessageTypeNew(t *testing.T) { } func TestParseOptMessageType(t *testing.T) { - data := []byte{53, 1, 1} // DISCOVER + data := []byte{1} // DISCOVER o, err := ParseOptMessageType(data) require.NoError(t, err) require.Equal(t, &OptMessageType{MessageType: MessageTypeDiscover}, o) - // Short byte stream - data = []byte{53, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{53, 5, 1} + data = []byte{1, 2} _, err = ParseOptMessageType(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/option_ntp_servers.go b/dhcpv4/option_ntp_servers.go index 39881d6..6d30920 100644 --- a/dhcpv4/option_ntp_servers.go +++ b/dhcpv4/option_ntp_servers.go @@ -15,26 +15,11 @@ type OptNTPServers struct { // ParseOptNTPServers returns a new OptNTPServers from a byte stream, or error if any. func ParseOptNTPServers(data []byte) (*OptNTPServers, error) { - if len(data) < 2 { - return nil, ErrShortByteStream + ips, err := ParseIPs(data) + if err != nil { + return nil, err } - code := OptionCode(data[0]) - if code != OptionNTPServers { - return nil, fmt.Errorf("expected code %v, got %v", OptionNTPServers, code) - } - length := int(data[1]) - if length == 0 || length%4 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 4 larger than 4, got %v", length) - } - if len(data) < 2+length { - return nil, ErrShortByteStream - } - ntpServers := make([]net.IP, 0, length%4) - for idx := 0; idx < length; idx += 4 { - b := data[2+idx : 2+idx+4] - ntpServers = append(ntpServers, net.IPv4(b[0], b[1], b[2], b[3])) - } - return &OptNTPServers{NTPServers: ntpServers}, nil + return &OptNTPServers{NTPServers: ips}, nil } // Code returns the option code. diff --git a/dhcpv4/option_ntp_servers_test.go b/dhcpv4/option_ntp_servers_test.go index e7bcefd..4d321ff 100644 --- a/dhcpv4/option_ntp_servers_test.go +++ b/dhcpv4/option_ntp_servers_test.go @@ -25,37 +25,23 @@ func TestParseOptNTPServers(t *testing.T) { 192, 168, 0, 10, // NTP server #1 192, 168, 0, 20, // NTP server #2 } - o, err := ParseOptNTPServers(data) + o, err := ParseOptNTPServers(data[2:]) require.NoError(t, err) ntpServers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), + net.IP{192, 168, 0, 10}, + net.IP{192, 168, 0, 20}, } require.Equal(t, &OptNTPServers{NTPServers: ntpServers}, o) - // Short byte stream - data = []byte{byte(OptionNTPServers)} - _, err = ParseOptNTPServers(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptNTPServers(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionNTPServers), 6, 1, 1, 1} + data = []byte{1, 1, 1} _, err = ParseOptNTPServers(data) require.Error(t, err, "should get error from bad length") } func TestParseOptNTPserversNoNTPServers(t *testing.T) { // RFC2132 requires that at least one NTP server IP is specified - data := []byte{ - byte(OptionNTPServers), - 0, // Length - } - _, err := ParseOptNTPServers(data) + _, err := ParseOptNTPServers([]byte{}) require.Error(t, err) } diff --git a/dhcpv4/option_parameter_request_list.go b/dhcpv4/option_parameter_request_list.go index 865b2d7..8516e3b 100644 --- a/dhcpv4/option_parameter_request_list.go +++ b/dhcpv4/option_parameter_request_list.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "strings" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the parameter request list option @@ -16,23 +18,12 @@ type OptParameterRequestList struct { // ParseOptParameterRequestList returns a new OptParameterRequestList from a // byte stream, or error if any. func ParseOptParameterRequestList(data []byte) (*OptParameterRequestList, error) { - // Should at least have code + length byte. - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionParameterRequestList { - return nil, fmt.Errorf("expected code %v, got %v", OptionParameterRequestList, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, ErrShortByteStream - } - var requestedOpts []OptionCode - for _, opt := range data[2 : length+2] { - requestedOpts = append(requestedOpts, OptionCode(opt)) + buf := uio.NewBigEndianBuffer(data) + requestedOpts := make([]OptionCode, 0, buf.Len()) + for buf.Len() > 0 { + requestedOpts = append(requestedOpts, OptionCode(buf.Read8())) } - return &OptParameterRequestList{RequestedOpts: requestedOpts}, nil + return &OptParameterRequestList{RequestedOpts: requestedOpts}, buf.Error() } // Code returns the option code. diff --git a/dhcpv4/option_parameter_request_list_test.go b/dhcpv4/option_parameter_request_list_test.go index f600a70..d54fc0f 100644 --- a/dhcpv4/option_parameter_request_list_test.go +++ b/dhcpv4/option_parameter_request_list_test.go @@ -23,16 +23,7 @@ func TestParseOptParameterRequestList(t *testing.T) { o *OptParameterRequestList err error ) - o, err = ParseOptParameterRequestList([]byte{}) - require.Error(t, err, "empty byte stream") - - o, err = ParseOptParameterRequestList([]byte{55, 2}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptParameterRequestList([]byte{53, 2, 1, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptParameterRequestList([]byte{55, 2, 67, 5}) + o, err = ParseOptParameterRequestList([]byte{67, 5}) require.NoError(t, err) expectedOpts := []OptionCode{OptionBootfileName, OptionNameServer} require.Equal(t, expectedOpts, o.RequestedOpts) diff --git a/dhcpv4/option_relay_agent_information.go b/dhcpv4/option_relay_agent_information.go index 447783e..d6547fe 100644 --- a/dhcpv4/option_relay_agent_information.go +++ b/dhcpv4/option_relay_agent_information.go @@ -8,42 +8,19 @@ import "fmt" // OptRelayAgentInformation is a "container" option for specific agent-supplied // sub-options. type OptRelayAgentInformation struct { - Options []Option + Options Options } // ParseOptRelayAgentInformation returns a new OptRelayAgentInformation from a // byte stream, or error if any. func ParseOptRelayAgentInformation(data []byte) (*OptRelayAgentInformation, error) { - if len(data) < 4 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRelayAgentInformation { - return nil, fmt.Errorf("expected code %v, got %v", OptionRelayAgentInformation, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - options, err := OptionsFromBytesWithParser(data[2:length+2], relayParseOption) + options, err := OptionsFromBytesWithParser(data, ParseOptionGeneric, false /* don't check for OptionEnd tag */) if err != nil { return nil, err } return &OptRelayAgentInformation{Options: options}, nil } -func relayParseOption(data []byte) (Option, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptionGeneric{OptionCode: code, Data: data[2:length+2]}, nil -} - // Code returns the option code. func (o *OptRelayAgentInformation) Code() OptionCode { return OptionRelayAgentInformation diff --git a/dhcpv4/option_relay_agent_information_test.go b/dhcpv4/option_relay_agent_information_test.go index 1e99206..bff8ced 100644 --- a/dhcpv4/option_relay_agent_information_test.go +++ b/dhcpv4/option_relay_agent_information_test.go @@ -14,33 +14,21 @@ func TestParseOptRelayAgentInformation(t *testing.T) { 2, 4, 'b', 'o', 'o', 't', } - // short option bytes - opt, err := ParseOptRelayAgentInformation([]byte{}) - require.Error(t, err) - - // wrong code - opt, err = ParseOptRelayAgentInformation([]byte{1, 2, 1, 0}) - require.Error(t, err) - - // wrong option length - opt, err = ParseOptRelayAgentInformation([]byte{82, 3, 1, 0}) - require.Error(t, err) - // short sub-option bytes - opt, err = ParseOptRelayAgentInformation([]byte{82, 3, 1, 0, 1}) + opt, err := ParseOptRelayAgentInformation([]byte{1, 0, 1}) require.Error(t, err) // short sub-option length - opt, err = ParseOptRelayAgentInformation([]byte{82, 2, 1, 1}) + opt, err = ParseOptRelayAgentInformation([]byte{1, 1}) require.Error(t, err) - opt, err = ParseOptRelayAgentInformation(data) + opt, err = ParseOptRelayAgentInformation(data[2:]) require.NoError(t, err) require.Equal(t, len(opt.Options), 2) - circuit, ok := opt.Options[0].(*OptionGeneric) - require.True(t, ok) - remote, ok := opt.Options[1].(*OptionGeneric) - require.True(t, ok) + circuit := opt.Options.GetOneOption(1).(*OptionGeneric) + require.NoError(t, err) + remote := opt.Options.GetOneOption(2).(*OptionGeneric) + require.NoError(t, err) require.Equal(t, circuit.Data, []byte("linux")) require.Equal(t, remote.Data, []byte("boot")) } diff --git a/dhcpv4/option_requested_ip_address.go b/dhcpv4/option_requested_ip_address.go index 8662263..3539278 100644 --- a/dhcpv4/option_requested_ip_address.go +++ b/dhcpv4/option_requested_ip_address.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the requested IP address option @@ -17,21 +19,8 @@ type OptRequestedIPAddress struct { // ParseOptRequestedIPAddress returns a new OptServerIdentifier from a byte // stream, or error if any. func ParseOptRequestedIPAddress(data []byte) (*OptRequestedIPAddress, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRequestedIPAddress { - return nil, fmt.Errorf("expected code %v, got %v", OptionRequestedIPAddress, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptRequestedIPAddress{RequestedAddr: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptRequestedIPAddress{RequestedAddr: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_requested_ip_address_test.go b/dhcpv4/option_requested_ip_address_test.go index efd6299..99592ea 100644 --- a/dhcpv4/option_requested_ip_address_test.go +++ b/dhcpv4/option_requested_ip_address_test.go @@ -29,16 +29,10 @@ func TestParseOptRequestedIPAddress(t *testing.T) { o, err = ParseOptRequestedIPAddress([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptRequestedIPAddress([]byte{byte(OptionRequestedIPAddress), 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptRequestedIPAddress([]byte{byte(OptionRequestedIPAddress), 3, 192, 168, 0, 1}) + o, err = ParseOptRequestedIPAddress([]byte{192}) require.Error(t, err, "wrong IP length") - o, err = ParseOptRequestedIPAddress([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptRequestedIPAddress([]byte{byte(OptionRequestedIPAddress), 4, 192, 168, 0, 1}) + o, err = ParseOptRequestedIPAddress([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.RequestedAddr) } diff --git a/dhcpv4/option_root_path.go b/dhcpv4/option_root_path.go index 504ed17..ba6f03f 100644 --- a/dhcpv4/option_root_path.go +++ b/dhcpv4/option_root_path.go @@ -15,19 +15,7 @@ type OptRootPath struct { // ParseOptRootPath constructs an OptRootPath struct from a sequence of bytes // and returns it, or an error. func ParseOptRootPath(data []byte) (*OptRootPath, error) { - // Should at least have code and length - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRootPath { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionRootPath, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptRootPath{Path: string(data[2 : 2+length])}, nil + return &OptRootPath{Path: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_root_path_test.go b/dhcpv4/option_root_path_test.go index 53de45b..4bc7bc1 100644 --- a/dhcpv4/option_root_path_test.go +++ b/dhcpv4/option_root_path_test.go @@ -20,24 +20,9 @@ func TestOptRootPathInterfaceMethods(t *testing.T) { func TestParseOptRootPath(t *testing.T) { data := []byte{byte(OptionRootPath), 4, '/', 'f', 'o', 'o'} - o, err := ParseOptRootPath(data) + o, err := ParseOptRootPath(data[2:]) require.NoError(t, err) require.Equal(t, &OptRootPath{Path: "/foo"}, o) - - // Short byte stream - data = []byte{byte(OptionRootPath)} - _, err = ParseOptRootPath(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{43, 2, 1, 1} - _, err = ParseOptRootPath(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionRootPath), 6, 1, 1, 1} - _, err = ParseOptRootPath(data) - require.Error(t, err, "should get error from bad length") } func TestOptRootPathString(t *testing.T) { diff --git a/dhcpv4/option_router.go b/dhcpv4/option_router.go index 3154edd..60a96d1 100644 --- a/dhcpv4/option_router.go +++ b/dhcpv4/option_router.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the router option @@ -13,28 +15,30 @@ type OptRouter struct { Routers []net.IP } -// ParseOptRouter returns a new OptRouter from a byte stream, or error if any. -func ParseOptRouter(data []byte) (*OptRouter, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRouter { - return nil, fmt.Errorf("expected code %v, got %v", OptionRouter, code) - } - length := int(data[1]) - if length == 0 || length%4 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 4 larger than 4, got %v", length) +// ParseIPs parses an IPv4 address from a DHCP packet as used and specified by +// options in RFC 2132, Sections 3.5 through 3.13, 8.2, 8.3, 8.5, 8.6, 8.9, and +// 8.10. +func ParseIPs(data []byte) ([]net.IP, error) { + buf := uio.NewBigEndianBuffer(data) + + if buf.Len() == 0 { + return nil, fmt.Errorf("IP DHCP options must always list at least one IP") } - if len(data) < 2+length { - return nil, ErrShortByteStream + + ips := make([]net.IP, 0, buf.Len()/net.IPv4len) + for buf.Has(net.IPv4len) { + ips = append(ips, net.IP(buf.CopyN(net.IPv4len))) } - routers := make([]net.IP, 0, length%4) - for idx := 0; idx < length; idx += 4 { - b := data[2+idx : 2+idx+4] - routers = append(routers, net.IPv4(b[0], b[1], b[2], b[3])) + return ips, buf.FinError() +} + +// ParseOptRouter returns a new OptRouter from a byte stream, or error if any. +func ParseOptRouter(data []byte) (*OptRouter, error) { + ips, err := ParseIPs(data) + if err != nil { + return nil, err } - return &OptRouter{Routers: routers}, nil + return &OptRouter{Routers: ips}, nil } // Code returns the option code. diff --git a/dhcpv4/option_router_test.go b/dhcpv4/option_router_test.go index f492c22..3264ce3 100644 --- a/dhcpv4/option_router_test.go +++ b/dhcpv4/option_router_test.go @@ -25,11 +25,11 @@ func TestParseOptRouter(t *testing.T) { 192, 168, 0, 10, // Router #1 192, 168, 0, 20, // Router #2 } - o, err := ParseOptRouter(data) + o, err := ParseOptRouter(data[2:]) require.NoError(t, err) routers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), + net.IP{192, 168, 0, 10}, + net.IP{192, 168, 0, 20}, } require.Equal(t, &OptRouter{Routers: routers}, o) @@ -37,16 +37,6 @@ func TestParseOptRouter(t *testing.T) { data = []byte{byte(OptionRouter)} _, err = ParseOptRouter(data) require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptRouter(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionRouter), 6, 1, 1, 1} - _, err = ParseOptRouter(data) - require.Error(t, err, "should get error from bad length") } func TestParseOptRouterNoRouters(t *testing.T) { @@ -60,6 +50,6 @@ func TestParseOptRouterNoRouters(t *testing.T) { } func TestOptRouterString(t *testing.T) { - o := OptRouter{Routers: []net.IP{net.IPv4(192, 168, 0, 1), net.IPv4(192, 168, 0, 10)}} + o := OptRouter{Routers: []net.IP{net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 10}}} require.Equal(t, "Routers -> 192.168.0.1, 192.168.0.10", o.String()) } diff --git a/dhcpv4/option_server_identifier.go b/dhcpv4/option_server_identifier.go index 26c21a7..fd0311a 100644 --- a/dhcpv4/option_server_identifier.go +++ b/dhcpv4/option_server_identifier.go @@ -3,12 +3,14 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) +// OptServerIdentifier represents an option encapsulating the server identifier. +// // This option implements the server identifier option // https://tools.ietf.org/html/rfc2132 - -// OptServerIdentifier represents an option encapsulating the server identifier. type OptServerIdentifier struct { ServerID net.IP } @@ -16,21 +18,8 @@ type OptServerIdentifier struct { // ParseOptServerIdentifier returns a new OptServerIdentifier from a byte // stream, or error if any. func ParseOptServerIdentifier(data []byte) (*OptServerIdentifier, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionServerIdentifier { - return nil, fmt.Errorf("expected code %v, got %v", OptionServerIdentifier, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptServerIdentifier{ServerID: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptServerIdentifier{ServerID: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_server_identifier_test.go b/dhcpv4/option_server_identifier_test.go index 2951ebb..4a85469 100644 --- a/dhcpv4/option_server_identifier_test.go +++ b/dhcpv4/option_server_identifier_test.go @@ -29,16 +29,10 @@ func TestParseOptServerIdentifier(t *testing.T) { o, err = ParseOptServerIdentifier([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptServerIdentifier([]byte{54, 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptServerIdentifier([]byte{54, 3, 192, 168, 0, 1}) + o, err = ParseOptServerIdentifier([]byte{192, 168, 0}) require.Error(t, err, "wrong IP length") - o, err = ParseOptServerIdentifier([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptServerIdentifier([]byte{54, 4, 192, 168, 0, 1}) + o, err = ParseOptServerIdentifier([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.ServerID) } diff --git a/dhcpv4/option_subnet_mask.go b/dhcpv4/option_subnet_mask.go index f1ff4a4..86ce004 100644 --- a/dhcpv4/option_subnet_mask.go +++ b/dhcpv4/option_subnet_mask.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the subnet mask option @@ -16,21 +18,8 @@ type OptSubnetMask struct { // ParseOptSubnetMask returns a new OptSubnetMask from a byte // stream, or error if any. func ParseOptSubnetMask(data []byte) (*OptSubnetMask, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionSubnetMask { - return nil, fmt.Errorf("expected code %v, got %v", OptionSubnetMask, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptSubnetMask{SubnetMask: net.IPMask(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptSubnetMask{SubnetMask: net.IPMask(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_subnet_mask_test.go b/dhcpv4/option_subnet_mask_test.go index 4cb8819..e3c37bf 100644 --- a/dhcpv4/option_subnet_mask_test.go +++ b/dhcpv4/option_subnet_mask_test.go @@ -29,16 +29,10 @@ func TestParseOptSubnetMask(t *testing.T) { o, err = ParseOptSubnetMask([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptSubnetMask([]byte{1, 4, 255}) + o, err = ParseOptSubnetMask([]byte{255}) require.Error(t, err, "short byte stream") - o, err = ParseOptSubnetMask([]byte{1, 3, 255, 255, 255, 0}) - require.Error(t, err, "wrong IP length") - - o, err = ParseOptSubnetMask([]byte{2, 4, 255, 255, 255}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptSubnetMask([]byte{1, 4, 255, 255, 255, 0}) + o, err = ParseOptSubnetMask([]byte{255, 255, 255, 0}) require.NoError(t, err) require.Equal(t, net.IPMask{255, 255, 255, 0}, o.SubnetMask) } diff --git a/dhcpv4/option_tftp_server_name.go b/dhcpv4/option_tftp_server_name.go index 3a310f9..2a4af6d 100644 --- a/dhcpv4/option_tftp_server_name.go +++ b/dhcpv4/option_tftp_server_name.go @@ -9,7 +9,7 @@ import ( // OptTFTPServerName implements the TFTP server name option. type OptTFTPServerName struct { - TFTPServerName []byte + TFTPServerName string } // Code returns the option code @@ -19,7 +19,7 @@ func (op *OptTFTPServerName) Code() OptionCode { // ToBytes serializes the option and returns it as a sequence of bytes func (op *OptTFTPServerName) ToBytes() []byte { - return append([]byte{byte(op.Code()), byte(op.Length())}, op.TFTPServerName...) + return append([]byte{byte(op.Code()), byte(op.Length())}, []byte(op.TFTPServerName)...) } // Length returns the option length in bytes @@ -33,22 +33,5 @@ func (op *OptTFTPServerName) String() string { // ParseOptTFTPServerName returns a new OptTFTPServerName from a byte stream or error if any func ParseOptTFTPServerName(data []byte) (*OptTFTPServerName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionTFTPServerName { - return nil, fmt.Errorf("ParseOptTFTPServerName: invalid code: %v; want %v", - code, OptionTFTPServerName) - } - length := int(data[1]) - if length < 1 { - return nil, fmt.Errorf("TFTP server name has invalid length of %d", length) - } - TFTPServerNameData := data[2:] - if len(TFTPServerNameData) < length { - return nil, fmt.Errorf("ParseOptTFTPServerName: short data: %d bytes; want %d", - len(TFTPServerNameData), length) - } - return &OptTFTPServerName{TFTPServerName: TFTPServerNameData[:length]}, nil + return &OptTFTPServerName{TFTPServerName: string(data)}, nil } diff --git a/dhcpv4/option_tftp_server_name_test.go b/dhcpv4/option_tftp_server_name_test.go index 812210f..54efef8 100644 --- a/dhcpv4/option_tftp_server_name_test.go +++ b/dhcpv4/option_tftp_server_name_test.go @@ -13,7 +13,7 @@ func TestOptTFTPServerNameCode(t *testing.T) { func TestOptTFTPServerNameToBytes(t *testing.T) { opt := OptTFTPServerName{ - TFTPServerName: []byte("linuxboot"), + TFTPServerName: "linuxboot", } data := opt.ToBytes() expected := []byte{ @@ -26,7 +26,7 @@ func TestOptTFTPServerNameToBytes(t *testing.T) { func TestParseOptTFTPServerName(t *testing.T) { expected := []byte{ - 66, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptTFTPServerName(expected) require.NoError(t, err) @@ -34,32 +34,7 @@ func TestParseOptTFTPServerName(t *testing.T) { require.Equal(t, "linuxboot", string(opt.TFTPServerName)) } -func TestParseOptTFTPServerNameZeroLength(t *testing.T) { - expected := []byte{ - 66, 0, - } - _, err := ParseOptTFTPServerName(expected) - require.Error(t, err) -} - -func TestParseOptTFTPServerNameInvalidLength(t *testing.T) { - expected := []byte{ - 66, 9, 'l', 'i', 'n', 'u', 'x', 'b', - } - _, err := ParseOptTFTPServerName(expected) - require.Error(t, err) -} - -func TestParseOptTFTPServerNameShortLength(t *testing.T) { - expected := []byte{ - 66, 4, 'l', 'i', 'n', 'u', 'x', - } - opt, err := ParseOptTFTPServerName(expected) - require.NoError(t, err) - require.Equal(t, []byte("linu"), opt.TFTPServerName) -} - func TestOptTFTPServerNameString(t *testing.T) { - o := OptTFTPServerName{TFTPServerName: []byte("testy test")} + o := OptTFTPServerName{TFTPServerName: "testy test"} require.Equal(t, "TFTP Server Name -> testy test", o.String()) } diff --git a/dhcpv4/option_userclass.go b/dhcpv4/option_userclass.go index d6ddabc..44ce090 100644 --- a/dhcpv4/option_userclass.go +++ b/dhcpv4/option_userclass.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "strings" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the User Class option @@ -12,7 +14,7 @@ import ( // OptUserClass represents an option encapsulating User Classes. type OptUserClass struct { UserClasses [][]byte - Rfc3004 bool + Rfc3004 bool } // Code returns the option code @@ -61,21 +63,7 @@ func (op *OptUserClass) String() string { // error if any func ParseOptUserClass(data []byte) (*OptUserClass, error) { opt := OptUserClass{} - - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionUserClassInformation { - return nil, fmt.Errorf("expected code %v, got %v", OptionUserClassInformation, code) - } - - totalLength := int(data[1]) - data = data[2:] - if len(data) < totalLength { - return nil, fmt.Errorf("ParseOptUserClass: short data: length is %d but got %d bytes", - totalLength, len(data)) - } + buf := uio.NewBigEndianBuffer(data) // Check if option is Microsoft style instead of RFC compliant, issue #113 @@ -85,29 +73,24 @@ func ParseOptUserClass(data []byte) (*OptUserClass, error) { // option length. If the lengths don't add up, we assume that the option // is a single string and non RFC3004 compliant var counting int - for counting < totalLength { + for counting < buf.Len() { // UC_Len_i does not include itself so add 1 counting += int(data[counting]) + 1 } - if counting != totalLength { - opt.UserClasses = append(opt.UserClasses, data[:totalLength]) + if counting != buf.Len() { + opt.UserClasses = append(opt.UserClasses, data) return &opt, nil } opt.Rfc3004 = true - for i := 0; i < totalLength; { - ucLen := int(data[i]) + for buf.Has(1) { + ucLen := buf.Read8() if ucLen == 0 { - return nil, errors.New("User Class value has invalid length of 0") - } - base := i + 1 - if len(data) < base+ucLen { - return nil, fmt.Errorf("ParseOptUserClass: short data: %d bytes; want: %d", len(data), base+ucLen) + return nil, fmt.Errorf("DHCP user class must have length greater than 0") } - opt.UserClasses = append(opt.UserClasses, data[base:base+ucLen]) - i += base + ucLen + opt.UserClasses = append(opt.UserClasses, buf.CopyN(int(ucLen))) } - if len(opt.UserClasses) < 1 { + if len(opt.UserClasses) == 0 { return nil, errors.New("ParseOptUserClass: at least one user class is required") } - return &opt, nil + return &opt, buf.FinError() } diff --git a/dhcpv4/option_userclass_test.go b/dhcpv4/option_userclass_test.go index f6039df..d392ed8 100644 --- a/dhcpv4/option_userclass_test.go +++ b/dhcpv4/option_userclass_test.go @@ -9,7 +9,7 @@ import ( func TestOptUserClassToBytes(t *testing.T) { opt := OptUserClass{ UserClasses: [][]byte{[]byte("linuxboot")}, - Rfc3004: true, + Rfc3004: true, } data := opt.ToBytes() expected := []byte{ @@ -35,7 +35,6 @@ func TestOptUserClassMicrosoftToBytes(t *testing.T) { func TestParseOptUserClassMultiple(t *testing.T) { expected := []byte{ - 77, 15, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 4, 't', 'e', 's', 't', } @@ -54,7 +53,7 @@ func TestParseOptUserClassNone(t *testing.T) { func TestParseOptUserClassMicrosoft(t *testing.T) { expected := []byte{ - 77, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) @@ -64,7 +63,7 @@ func TestParseOptUserClassMicrosoft(t *testing.T) { func TestParseOptUserClassMicrosoftShort(t *testing.T) { expected := []byte{ - 77, 1, 'l', + 'l', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) @@ -72,19 +71,9 @@ func TestParseOptUserClassMicrosoftShort(t *testing.T) { require.Equal(t, []byte("l"), opt.UserClasses[0]) } -func TestParseOptUserClassMicrosoftLongerThanLength(t *testing.T) { - expected := []byte{ - 77, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', - } - opt, err := ParseOptUserClass(expected) - require.NoError(t, err) - require.Equal(t, 1, len(opt.UserClasses)) - require.Equal(t, []byte("linuxboot"), opt.UserClasses[0]) -} - func TestParseOptUserClass(t *testing.T) { expected := []byte{ - 77, 10, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) @@ -110,44 +99,18 @@ func TestOptUserClassToBytesMultiple(t *testing.T) { require.Equal(t, expected, data) } -func TestParseOptUserClassLongerThanLength(t *testing.T) { - expected := []byte{ - 77, 10, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', - } - opt, err := ParseOptUserClass(expected) - require.NoError(t, err) - require.Equal(t, 1, len(opt.UserClasses)) - require.Equal(t, []byte("linuxboot"), opt.UserClasses[0]) -} - -func TestParseOptUserClassShorterTotalLength(t *testing.T) { - expected := []byte{ - 77, 11, 10, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', - } - _, err := ParseOptUserClass(expected) - require.Error(t, err) -} - func TestOptUserClassLength(t *testing.T) { expected := []byte{ - 77, 10, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', + 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) - require.Equal(t, 10, opt.Length()) + require.Equal(t, 11, opt.Length()) } func TestParseOptUserClassZeroLength(t *testing.T) { expected := []byte{ - 77, 1, 0, 0, - } - _, err := ParseOptUserClass(expected) - require.Error(t, err) -} - -func TestParseOptUserClassMultipleWithZeroLength(t *testing.T) { - expected := []byte{ - 77, 12, 10, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 0, + 0, 0, } _, err := ParseOptUserClass(expected) require.Error(t, err) diff --git a/dhcpv4/option_vivc.go b/dhcpv4/option_vivc.go index 7576637..4ff42a3 100644 --- a/dhcpv4/option_vivc.go +++ b/dhcpv4/option_vivc.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the Vendor-Identifying Vendor Class Option @@ -23,39 +25,16 @@ type OptVIVC struct { // ParseOptVIVC contructs an OptVIVC tsruct from a sequence of bytes and returns // it, or an error. func ParseOptVIVC(data []byte) (*OptVIVC, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionVendorIdentifyingVendorClass { - return nil, fmt.Errorf("expected code %v, got %v", OptionVendorIdentifyingVendorClass, code) - } - - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - data = data[2:length+2] - - ids := []VIVCIdentifier{} - for len(data) > 5 { - entID := binary.BigEndian.Uint32(data[0:4]) - idLen := int(data[4]) - data = data[5:] - - if idLen > len(data) { - return nil, ErrShortByteStream - } - - ids = append(ids, VIVCIdentifier{EntID: entID, Data: data[:idLen]}) - data = data[idLen:] - } + buf := uio.NewBigEndianBuffer(data) - if len(data) != 0 { - return nil, ErrShortByteStream + var ids []VIVCIdentifier + for buf.Has(5) { + entID := buf.Read32() + idLen := int(buf.Read8()) + ids = append(ids, VIVCIdentifier{EntID: entID, Data: buf.CopyN(idLen)}) } - return &OptVIVC{Identifiers: ids}, nil + return &OptVIVC{Identifiers: ids}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_vivc_test.go b/dhcpv4/option_vivc_test.go index b290da0..0b4ab7c 100644 --- a/dhcpv4/option_vivc_test.go +++ b/dhcpv4/option_vivc_test.go @@ -31,36 +31,20 @@ func TestOptVIVCInterfaceMethods(t *testing.T) { } func TestParseOptVICO(t *testing.T) { - o, err := ParseOptVIVC(sampleVIVCOptRaw) + o, err := ParseOptVIVC(sampleVIVCOptRaw[2:]) require.NoError(t, err) require.Equal(t, &sampleVIVCOpt, o) - // Short byte stream - data := []byte{byte(OptionVendorIdentifyingVendorClass)} - _, err = ParseOptVIVC(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptVIVC(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionVendorIdentifyingVendorClass), 6, 1, 1, 1} - _, err = ParseOptVIVC(data) - require.Error(t, err, "should get error from bad length") - // Identifier len too long - data = make([]byte, len(sampleVIVCOptRaw)) - copy(data, sampleVIVCOptRaw) - data[6] = 40 + data := make([]byte, len(sampleVIVCOptRaw[2:])) + copy(data, sampleVIVCOptRaw[2:]) + data[4] = 40 _, err = ParseOptVIVC(data) require.Error(t, err, "should get error from bad length") // Longer than length - data[1] = 10 - data[6] = 5 - o, err = ParseOptVIVC(data) + data[4] = 5 + o, err = ParseOptVIVC(data[:10]) require.NoError(t, err) require.Equal(t, o.Identifiers[0].Data, []byte("Cisco")) } diff --git a/dhcpv4/options.go b/dhcpv4/options.go index a5a51be..215da39 100644 --- a/dhcpv4/options.go +++ b/dhcpv4/options.go @@ -2,15 +2,26 @@ package dhcpv4 import ( "errors" + "fmt" + "io" + + "github.com/u-root/u-root/pkg/uio" ) -// ErrShortByteStream is an error that is thrown any time a short byte stream is -// detected during option parsing. -var ErrShortByteStream = errors.New("short byte stream") +var ( + // ErrShortByteStream is an error that is thrown any time a short byte stream is + // detected during option parsing. + ErrShortByteStream = errors.New("short byte stream") + + // ErrZeroLengthByteStream is an error that is thrown any time a zero-length + // byte stream is encountered. + ErrZeroLengthByteStream = errors.New("zero-length byte stream") -// ErrZeroLengthByteStream is an error that is thrown any time a zero-length -// byte stream is encountered. -var ErrZeroLengthByteStream = errors.New("zero-length byte stream") + // ErrInvalidOptions is returned when invalid options data is + // encountered during parsing. The data could report an incorrect + // length or have trailing bytes which are not part of the option. + ErrInvalidOptions = errors.New("invalid options data") +) // OptionCode is a single byte representing the code for a given Option. type OptionCode byte @@ -25,15 +36,12 @@ type Option interface { // ParseOption parses a sequence of bytes as a single DHCPv4 option, returning // the specific option structure or error, if any. -func ParseOption(data []byte) (Option, error) { - if len(data) == 0 { - return nil, errors.New("invalid zero-length DHCPv4 option") - } +func ParseOption(code OptionCode, data []byte) (Option, error) { var ( opt Option err error ) - switch OptionCode(data[0]) { + switch code { case OptionSubnetMask: opt, err = ParseOptSubnetMask(data) case OptionRouter: @@ -79,7 +87,7 @@ func ParseOption(data []byte) (Option, error) { case OptionVendorIdentifyingVendorClass: opt, err = ParseOptVIVC(data) default: - opt, err = ParseOptionGeneric(data) + opt, err = ParseOptionGeneric(code, data) } if err != nil { return nil, err @@ -94,41 +102,74 @@ func ParseOption(data []byte) (Option, error) { // // Returns an error if any invalid option or length is found. func OptionsFromBytes(data []byte) ([]Option, error) { - return OptionsFromBytesWithParser(data, ParseOption) + return OptionsFromBytesWithParser(data, ParseOption, true) } // OptionParser is a function signature for option parsing -type OptionParser func(data []byte) (Option, error) +type OptionParser func(code OptionCode, data []byte) (Option, error) // OptionsFromBytesWithParser parses Options from byte sequences using the // parsing function that is passed in as a paremeter -func OptionsFromBytesWithParser(data []byte, parser OptionParser) ([]Option, error) { - options := make([]Option, 0, 10) - idx := 0 - for { - if idx == len(data) { +func OptionsFromBytesWithParser(data []byte, parser OptionParser, checkEndOption bool) (Options, error) { + if len(data) == 0 { + return nil, nil + } + buf := uio.NewBigEndianBuffer(data) + options := make(map[OptionCode][]byte, 10) + var order []OptionCode + + // Due to RFC 3396 allowing an option to be specified multiple times, + // we have to collect all option data first, and then parse it. + var end bool + for buf.Len() >= 1 { + // 1 byte: option code + // 1 byte: option length n + // n bytes: data + code := OptionCode(buf.Read8()) + + if code == OptionPad { + continue + } else if code == OptionEnd { + end = true break } - // This should never happen. - if idx > len(data) { - return nil, errors.New("read past the end of options") + length := int(buf.Read8()) + + // N bytes: option data + data := buf.Consume(length) + if data == nil { + return nil, fmt.Errorf("error collecting options: %v", buf.Error()) } - opt, err := parser(data[idx:]) - idx++ - if err != nil { - return nil, err + data = data[:length:length] + + if _, ok := options[code]; !ok { + order = append(order, code) } - options = append(options, opt) - if opt.Code() == OptionEnd { - break + // RFC 3396: Just concatenate the data if the option code was + // specified multiple times. + options[code] = append(options[code], data...) + } + + // If we never read the End option, the sender of this packet screwed + // up. + if !end && checkEndOption { + return nil, io.ErrUnexpectedEOF + } + + // Any bytes left must be padding. + for buf.Len() >= 1 { + if OptionCode(buf.Read8()) != OptionPad { + return nil, ErrInvalidOptions } + } - // Options with zero length have no length byte, so here we handle the - // ones with nonzero length - if opt.Code() != OptionPad { - idx++ + opts := make(Options, 0, 10) + for _, code := range order { + parsedOpt, err := parser(code, options[code]) + if err != nil { + return nil, fmt.Errorf("error parsing option code %s: %v", code, err) } - idx += opt.Length() + opts = append(opts, parsedOpt) } - return options, nil + return opts, nil } diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go index 2557af5..9a2f1c0 100644 --- a/dhcpv4/options_test.go +++ b/dhcpv4/options_test.go @@ -1,6 +1,9 @@ package dhcpv4 import ( + "bytes" + "fmt" + "math" "testing" "github.com/stretchr/testify/require" @@ -9,7 +12,7 @@ import ( func TestParseOption(t *testing.T) { // Generic option := []byte{5, 4, 192, 168, 1, 254} // DNS option - opt, err := ParseOption(option) + opt, err := ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) generic := opt.(*OptionGeneric) require.Equal(t, OptionNameServer, generic.Code()) @@ -19,7 +22,7 @@ func TestParseOption(t *testing.T) { // Option subnet mask option = []byte{1, 4, 255, 255, 255, 0} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionSubnetMask, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -27,7 +30,7 @@ func TestParseOption(t *testing.T) { // Option router option = []byte{3, 4, 192, 168, 1, 1} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRouter, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -35,7 +38,7 @@ func TestParseOption(t *testing.T) { // Option domain name server option = []byte{6, 4, 192, 168, 1, 1} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionDomainNameServer, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -43,7 +46,7 @@ func TestParseOption(t *testing.T) { // Option host name option = []byte{12, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionHostName, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -51,7 +54,7 @@ func TestParseOption(t *testing.T) { // Option domain name option = []byte{15, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionDomainName, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -59,7 +62,7 @@ func TestParseOption(t *testing.T) { // Option root path option = []byte{17, 4, '/', 'f', 'o', 'o'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRootPath, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -67,7 +70,7 @@ func TestParseOption(t *testing.T) { // Option broadcast address option = []byte{28, 4, 255, 255, 255, 255} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionBroadcastAddress, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -75,7 +78,7 @@ func TestParseOption(t *testing.T) { // Option NTP servers option = []byte{42, 4, 10, 10, 10, 10} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionNTPServers, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -83,7 +86,7 @@ func TestParseOption(t *testing.T) { // Requested IP address option = []byte{50, 4, 1, 2, 3, 4} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRequestedIPAddress, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -91,7 +94,7 @@ func TestParseOption(t *testing.T) { // Requested IP address lease time option = []byte{51, 4, 0, 0, 0, 0} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionIPAddressLeaseTime, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -99,7 +102,7 @@ func TestParseOption(t *testing.T) { // Message type option = []byte{53, 1, 1} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionDHCPMessageType, opt.Code(), "Code") require.Equal(t, 1, opt.Length(), "Length") @@ -107,7 +110,7 @@ func TestParseOption(t *testing.T) { // Option server ID option = []byte{54, 4, 1, 2, 3, 4} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionServerIdentifier, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -115,7 +118,7 @@ func TestParseOption(t *testing.T) { // Parameter request list option = []byte{55, 3, 5, 53, 61} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionParameterRequestList, opt.Code(), "Code") require.Equal(t, 3, opt.Length(), "Length") @@ -123,7 +126,7 @@ func TestParseOption(t *testing.T) { // Option max message size option = []byte{57, 2, 1, 2} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionMaximumDHCPMessageSize, opt.Code(), "Code") require.Equal(t, 2, opt.Length(), "Length") @@ -131,7 +134,7 @@ func TestParseOption(t *testing.T) { // Option class identifier option = []byte{60, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionClassIdentifier, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -139,7 +142,7 @@ func TestParseOption(t *testing.T) { // Option TFTP server name option = []byte{66, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionTFTPServerName, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -147,7 +150,7 @@ func TestParseOption(t *testing.T) { // Option Bootfile name option = []byte{67, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionBootfileName, opt.Code(), "Code") require.Equal(t, 9, opt.Length(), "Length") @@ -155,37 +158,131 @@ func TestParseOption(t *testing.T) { // Option user class information option = []byte{77, 5, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionUserClassInformation, opt.Code(), "Code") require.Equal(t, 5, opt.Length(), "Length") require.Equal(t, option, opt.ToBytes(), "ToBytes") // Option relay agent information - option = []byte{82, 2, 1, 0} - opt, err = ParseOption(option) + option = []byte{82, 6, 1, 4, 129, 168, 0, 1} + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRelayAgentInformation, opt.Code(), "Code") - require.Equal(t, 2, opt.Length(), "Length") + require.Equal(t, 6, opt.Length(), "Length") require.Equal(t, option, opt.ToBytes(), "ToBytes") // Option client system architecture type option option = []byte{93, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionClientSystemArchitectureType, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") require.Equal(t, option, opt.ToBytes(), "ToBytes") } -func TestParseOptionZeroLength(t *testing.T) { - option := []byte{} - _, err := ParseOption(option) - require.Error(t, err, "should get error from zero-length options") -} - -func TestParseOptionShortOption(t *testing.T) { - option := []byte{53, 1} - _, err := ParseOption(option) - require.Error(t, err, "should get error from short options") +func TestOptionsUnmarshal(t *testing.T) { + for i, tt := range []struct { + input []byte + want Options + wantError bool + }{ + { + // Buffer missing data. + input: []byte{ + 3 /* key */, 3 /* length */, 1, + }, + wantError: true, + }, + { + input: []byte{ + // This may look too long, but 0 is padding. + // The issue here is the missing OptionEnd. + 3, 3, 0, 0, 0, 0, 0, 0, 0, + }, + wantError: true, + }, + { + // Only OptionPad and OptionEnd can stand on their own + // without a length field. So this is too short. + input: []byte{ + 3, + }, + wantError: true, + }, + { + // Option present after the End is a nono. + input: []byte{byte(OptionEnd), 3}, + wantError: true, + }, + { + input: []byte{byte(OptionEnd)}, + want: Options{}, + }, + { + input: []byte{ + 3, 2, 5, 6, + byte(OptionEnd), + }, + want: Options{ + &OptionGeneric{ + OptionCode: 3, + Data: []byte{5, 6}, + }, + }, + }, + { + // Test RFC 3396. + input: append( + append([]byte{3, math.MaxUint8}, bytes.Repeat([]byte{10}, math.MaxUint8)...), + 3, 5, 10, 10, 10, 10, 10, + byte(OptionEnd), + ), + want: Options{ + &OptionGeneric{ + OptionCode: 3, + Data: bytes.Repeat([]byte{10}, math.MaxUint8+5), + }, + }, + }, + { + input: []byte{ + 10, 2, 255, 254, + 11, 3, 5, 5, 5, + byte(OptionEnd), + }, + want: Options{ + &OptionGeneric{ + OptionCode: 10, + Data: []byte{255, 254}, + }, + &OptionGeneric{ + OptionCode: 11, + Data: []byte{5, 5, 5}, + }, + }, + }, + { + input: append( + append([]byte{10, 2, 255, 254}, bytes.Repeat([]byte{byte(OptionPad)}, 255)...), + byte(OptionEnd), + ), + want: Options{ + &OptionGeneric{ + OptionCode: 10, + Data: []byte{255, 254}, + }, + }, + }, + } { + t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) { + opt, err := OptionsFromBytesWithParser(tt.input, ParseOptionGeneric, true) + if tt.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, opt, tt.want) + } + }) + } } |