diff options
author | Christopher Koch <chrisko@google.com> | 2018-12-28 20:19:14 -0800 |
---|---|---|
committer | insomniac <insomniacslk@users.noreply.github.com> | 2019-01-10 23:01:22 +0000 |
commit | e87114a6e449d7a2e458b5529923e5668dfa3a11 (patch) | |
tree | 7c5ec350c8eb859400a159b881081ffe0dbd4ef0 /dhcpv4 | |
parent | 108ed92e1c9901936541020bc3214533acce77bb (diff) |
dhcpv4: simplify option parsing.
option's codes and lengths were being parsed twice: once in ParseOption
and once in each option type's Parse implementation. Consolidate such
that it only happens once.
Additionally, only pass data to options that they should parse -- we
know the length before the Parse function is called, so the option only
gets to see the data it needs to see.
Also, use uio.Lexer to simplify parsing code in general. Easier to read
and reason about.
Diffstat (limited to 'dhcpv4')
76 files changed, 544 insertions, 1277 deletions
diff --git a/dhcpv4/bsdp/boot_image.go b/dhcpv4/bsdp/boot_image.go index 88a9404..fa9b1a6 100644 --- a/dhcpv4/bsdp/boot_image.go +++ b/dhcpv4/bsdp/boot_image.go @@ -3,6 +3,8 @@ package bsdp import ( "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // BootImageType represents the different BSDP boot image types. @@ -60,16 +62,14 @@ func (b BootImageID) String() string { return s + " " + t + " image" } -// BootImageIDFromBytes deserializes a collection of 4 bytes to a BootImageID. -func BootImageIDFromBytes(bytes []byte) (*BootImageID, error) { - if len(bytes) < 4 { - return nil, fmt.Errorf("not enough bytes to serialize BootImageID") - } - return &BootImageID{ - IsInstall: bytes[0]&0x80 != 0, - ImageType: BootImageType(bytes[0] & 0x7f), - Index: binary.BigEndian.Uint16(bytes[2:]), - }, nil +// Unmarshal reads b's binary representation from buf. +func (b *BootImageID) Unmarshal(buf *uio.Lexer) error { + byte0 := buf.Read8() + _ = buf.Read8() + b.IsInstall = byte0&0x80 != 0 + b.ImageType = BootImageType(byte0 & 0x7f) + b.Index = buf.Read16() + return buf.Error() } // BootImage describes a boot image - contains the boot image ID and the name. @@ -87,25 +87,16 @@ func (b *BootImage) ToBytes() []byte { } // String converts a BootImage to a human-readable representation. -func (b *BootImage) String() string { +func (b BootImage) String() string { return fmt.Sprintf("%v %v", b.Name, b.ID.String()) } -// BootImageFromBytes returns a deserialized BootImage struct from bytes. -func BootImageFromBytes(bytes []byte) (*BootImage, error) { - // Should at least contain 4 bytes of BootImageID + byte for length of - // boot image name. - if len(bytes) < 5 { - return nil, fmt.Errorf("not enough bytes to serialize BootImage") - } - imageID, err := BootImageIDFromBytes(bytes[:4]) - if err != nil { - return nil, err - } - nameLength := int(bytes[4]) - if 5+nameLength > len(bytes) { - return nil, fmt.Errorf("not enough bytes for BootImage") +// Unmarshal reads data from buf into b. +func (b *BootImage) Unmarshal(buf *uio.Lexer) error { + if err := (&b.ID).Unmarshal(buf); err != nil { + return err } - name := string(bytes[5 : 5+nameLength]) - return &BootImage{ID: *imageID, Name: name}, nil + nameLength := buf.Read8() + b.Name = string(buf.Consume(int(nameLength))) + return buf.Error() } diff --git a/dhcpv4/bsdp/boot_image_test.go b/dhcpv4/bsdp/boot_image_test.go index 004aa30..d8e3aeb 100644 --- a/dhcpv4/bsdp/boot_image_test.go +++ b/dhcpv4/bsdp/boot_image_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/u-root/u-root/pkg/uio" ) func TestBootImageIDToBytes(t *testing.T) { @@ -28,25 +29,23 @@ func TestBootImageIDFromBytes(t *testing.T) { ImageType: BootImageTypeMacOSX, Index: 0x1000, } - newBootImage, err := BootImageIDFromBytes(b.ToBytes()) - require.NoError(t, err) - require.Equal(t, b, *newBootImage) + var newBootImage BootImageID + require.NoError(t, uio.FromBigEndian(&newBootImage, b.ToBytes())) + require.Equal(t, b, newBootImage) b = BootImageID{ IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 0x1011, } - newBootImage, err = BootImageIDFromBytes(b.ToBytes()) - require.NoError(t, err) - require.Equal(t, b, *newBootImage) + require.NoError(t, uio.FromBigEndian(&newBootImage, b.ToBytes())) + require.Equal(t, b, newBootImage) } func TestBootImageIDFromBytesFail(t *testing.T) { serialized := []byte{0x81, 0, 0x10} // intentionally left short - deserialized, err := BootImageIDFromBytes(serialized) - require.Nil(t, deserialized) - require.Error(t, err) + var deserialized BootImageID + require.Error(t, uio.FromBigEndian(&deserialized, serialized)) } func TestBootImageIDString(t *testing.T) { @@ -97,8 +96,8 @@ func TestBootImageFromBytes(t *testing.T) { 7, // len(Name) 98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name } - b, err := BootImageFromBytes(input) - require.NoError(t, err) + var b BootImage + require.NoError(t, uio.FromBigEndian(&b, input)) expectedBootImage := BootImage{ ID: BootImageID{ IsInstall: false, @@ -107,15 +106,14 @@ func TestBootImageFromBytes(t *testing.T) { }, Name: "bsdp-21", } - require.Equal(t, expectedBootImage, *b) + require.Equal(t, expectedBootImage, b) } func TestBootImageFromBytesOnlyBootImageID(t *testing.T) { // Only a BootImageID, nothing else. input := []byte{0x1, 0, 0x10, 0x10} - b, err := BootImageFromBytes(input) - require.Nil(t, b) - require.Error(t, err) + var b BootImage + require.Error(t, uio.FromBigEndian(&b, input)) } func TestBootImageFromBytesShortBootImage(t *testing.T) { @@ -124,9 +122,8 @@ func TestBootImageFromBytesShortBootImage(t *testing.T) { 7, // len(Name) 98, 115, 100, 112, 45, 50, // Name bytes (intentionally off-by-one) } - b, err := BootImageFromBytes(input) - require.Nil(t, b) - require.Error(t, err) + var b BootImage + require.Error(t, uio.FromBigEndian(&b, input)) } func TestBootImageString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 3f97602..6bc0cfd 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -23,7 +23,7 @@ const AppleVendorID = "AAPLBSDPC" type ReplyConfig struct { ServerIP net.IP ServerHostname, BootFileName string - ServerPriority int + ServerPriority uint16 Images []BootImage DefaultImage, SelectedImage *BootImage } @@ -36,7 +36,7 @@ func ParseBootImageListFromAck(ack dhcpv4.DHCPv4) ([]BootImage, error) { if opt == nil { return nil, errors.New("ParseBootImageListFromAck: could not find vendor-specific option") } - vendorOpt, err := ParseOptVendorSpecificInformation(opt.ToBytes()) + vendorOpt, err := ParseOptVendorSpecificInformation(opt.ToBytes()[2:]) if err != nil { return nil, err } @@ -60,7 +60,7 @@ func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) *MessageType { err error ) for _, opt := range packet.GetOption(dhcpv4.OptionVendorSpecificInformation) { - if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()); err == nil { + if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()[2:]); err == nil { if o := vendorOpts.GetOneOption(OptionMessageType); o != nil { if optMessageType, ok := o.(*OptMessageType); ok { return &optMessageType.Type diff --git a/dhcpv4/bsdp/bsdp_option_boot_image_list.go b/dhcpv4/bsdp/bsdp_option_boot_image_list.go index 6417221..d018655 100644 --- a/dhcpv4/bsdp/bsdp_option_boot_image_list.go +++ b/dhcpv4/bsdp/bsdp_option_boot_image_list.go @@ -1,9 +1,8 @@ package bsdp import ( - "fmt" - "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) // Implements the BSDP option listing the boot images. @@ -17,34 +16,15 @@ type OptBootImageList struct { // ParseOptBootImageList constructs an OptBootImageList struct from a sequence // of bytes and returns it, or an error. func ParseOptBootImageList(data []byte) (*OptBootImageList, error) { - // Should have at least code + length - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionBootImageList { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionBootImageList, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("expected length %d, got %d instead", length, len(data)) - } + buf := uio.NewBigEndianBuffer(data) - // Offset from code + length byte var bootImages []BootImage - idx := 2 - for { - if idx >= length+2 { - break - } - image, err := BootImageFromBytes(data[idx:]) - if err != nil { - return nil, fmt.Errorf("parsing bytes stream: %v", err) + for buf.Has(5) { + var image BootImage + if err := (&image).Unmarshal(buf); err != nil { + return nil, err } - bootImages = append(bootImages, *image) - - // 4 bytes of BootImageID, 1 byte of name length, name - idx += 4 + 1 + len(image.Name) + bootImages = append(bootImages, image) } return &OptBootImageList{bootImages}, nil diff --git a/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go b/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go index d2784ae..0819d64 100644 --- a/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go +++ b/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go @@ -45,8 +45,6 @@ func TestOptBootImageListInterfaceMethods(t *testing.T) { func TestParseOptBootImageList(t *testing.T) { data := []byte{ - 9, // code - 22, // length // boot image 1 0x1, 0x0, 0x03, 0xe9, // ID 6, // name length @@ -78,25 +76,8 @@ func TestParseOptBootImageList(t *testing.T) { } require.Equal(t, &OptBootImageList{expectedBootImages}, o) - // Short byte stream - data = []byte{9} - _, err = ParseOptBootImageList(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptBootImageList(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{9, 10, 1, 1, 1} - _, err = ParseOptBootImageList(data) - require.Error(t, err, "should get error from bad length") - // Error parsing boot image (malformed) data = []byte{ - 9, // code - 22, // length // boot image 1 0x1, 0x0, 0x03, 0xe9, // ID 4, // name length @@ -108,25 +89,6 @@ func TestParseOptBootImageList(t *testing.T) { } _, err = ParseOptBootImageList(data) require.Error(t, err, "should get error from bad boot image") - - // Should not get error parsing boot image with excess length. - data = []byte{ - 9, // code - 22, // length - // boot image 1 - 0x1, 0x0, 0x03, 0xe9, // ID - 6, // name length - 'b', 's', 'd', 'p', '-', '1', - // boot image 2 - 0x80, 0x0, 0x23, 0x31, // ID - 6, // name length - 'b', 's', 'd', 'p', '-', '2', - - // Simulate another option after boot image list - 7, 4, 0x80, 0x0, 0x23, 0x32, - } - _, err = ParseOptBootImageList(data) - require.NoError(t, err, "should not get error from options after boot image list") } func TestOptBootImageListString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go b/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go index 4c87df8..52f7780 100644 --- a/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go +++ b/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// OptDefaultBootImageID contains the selected boot image ID. +// // Implements the BSDP option default boot image ID, which tells the client // which image is the default boot image if one is not selected. - -// OptDefaultBootImageID contains the selected boot image ID. type OptDefaultBootImageID struct { ID BootImageID } @@ -17,22 +18,12 @@ type OptDefaultBootImageID struct { // ParseOptDefaultBootImageID constructs an OptDefaultBootImageID struct from a sequence of // bytes and returns it, or an error. func ParseOptDefaultBootImageID(data []byte) (*OptDefaultBootImageID, error) { - if len(data) < 6 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionDefaultBootImageID { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionDefaultBootImageID, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("expected length 4, got %d instead", length) - } - id, err := BootImageIDFromBytes(data[2:6]) - if err != nil { + var o OptDefaultBootImageID + buf := uio.NewBigEndianBuffer(data) + if err := o.ID.Unmarshal(buf); err != nil { return nil, err } - return &OptDefaultBootImageID{*id}, nil + return &o, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go b/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go index e062e2d..ad29c30 100644 --- a/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go +++ b/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go @@ -17,24 +17,17 @@ func TestOptDefaultBootImageIDInterfaceMethods(t *testing.T) { func TestParseOptDefaultBootImageID(t *testing.T) { b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - bootImageBytes := b.ToBytes() - data := append([]byte{byte(OptionDefaultBootImageID), 4}, bootImageBytes...) - o, err := ParseOptDefaultBootImageID(data) + o, err := ParseOptDefaultBootImageID(b.ToBytes()) require.NoError(t, err) require.Equal(t, &OptDefaultBootImageID{b}, o) // Short byte stream - data = []byte{byte(OptionDefaultBootImageID), 4} + data := []byte{} _, err = ParseOptDefaultBootImageID(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 0, 0, 0} - _, err = ParseOptDefaultBootImageID(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionDefaultBootImageID), 5, 1, 0, 0, 0, 0} + data = []byte{1, 0, 0, 0, 0} _, err = ParseOptDefaultBootImageID(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/bsdp/bsdp_option_generic.go b/dhcpv4/bsdp/bsdp_option_generic.go index 6a51e29..6702e9c 100644 --- a/dhcpv4/bsdp/bsdp_option_generic.go +++ b/dhcpv4/bsdp/bsdp_option_generic.go @@ -16,21 +16,8 @@ type OptGeneric struct { // ParseOptGeneric parses a bytestream and creates a new OptGeneric from it, // or an error. -func ParseOptGeneric(data []byte) (*OptGeneric, error) { - if len(data) == 0 { - return nil, dhcpv4.ErrZeroLengthByteStream - } - var ( - length int - optionData []byte - ) - code := dhcpv4.OptionCode(data[0]) - length = int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("invalid data length: declared %v, actual %v", length, len(data)) - } - optionData = data[2 : length+2] - return &OptGeneric{OptionCode: code, Data: optionData}, nil +func ParseOptGeneric(code dhcpv4.OptionCode, data []byte) (*OptGeneric, error) { + return &OptGeneric{OptionCode: code, Data: data}, nil } // Code returns the generic option code. @@ -45,7 +32,7 @@ func (o OptGeneric) ToBytes() []byte { // String returns a human-readable representation of a generic option. func (o OptGeneric) String() string { - code, ok := OptionCodeToString[o.Code()] + code, ok := optionCodeToString[o.Code()] if !ok { code = "Unknown" } diff --git a/dhcpv4/bsdp/bsdp_option_generic_test.go b/dhcpv4/bsdp/bsdp_option_generic_test.go index 27436dd..eae77e1 100644 --- a/dhcpv4/bsdp/bsdp_option_generic_test.go +++ b/dhcpv4/bsdp/bsdp_option_generic_test.go @@ -7,19 +7,11 @@ import ( ) func TestParseOptGeneric(t *testing.T) { - // Empty bytestream produces error - _, err := ParseOptGeneric([]byte{}) - require.Error(t, err, "error from empty bytestream") - // Good parse - o, err := ParseOptGeneric([]byte{1, 1, 1}) + o, err := ParseOptGeneric(OptionMessageType, []byte{1}) require.NoError(t, err) require.Equal(t, OptionMessageType, o.Code()) require.Equal(t, MessageTypeList, MessageType(o.Data[0])) - - // Bad parse - o, err = ParseOptGeneric([]byte{1, 2, 1}) - require.Error(t, err, "invalid length") } func TestOptGenericCode(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_machine_name.go b/dhcpv4/bsdp/bsdp_option_machine_name.go index dc05378..cffba2e 100644 --- a/dhcpv4/bsdp/bsdp_option_machine_name.go +++ b/dhcpv4/bsdp/bsdp_option_machine_name.go @@ -1,15 +1,13 @@ package bsdp import ( - "fmt" - "github.com/insomniacslk/dhcp/dhcpv4" ) +// OptMachineName represents a BSDP message type. +// // Implements the BSDP option machine name, which gives the Netboot server's // machine name. - -// OptMachineName represents a BSDP message type. type OptMachineName struct { Name string } @@ -17,18 +15,7 @@ type OptMachineName struct { // ParseOptMachineName constructs an OptMachineName struct from a sequence of // bytes and returns it, or an error. func ParseOptMachineName(data []byte) (*OptMachineName, error) { - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionMachineName { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionMachineName, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("expected length %d, got %d instead", length, len(data)) - } - return &OptMachineName{Name: string(data[2 : length+2])}, nil + return &OptMachineName{Name: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_machine_name_test.go b/dhcpv4/bsdp/bsdp_option_machine_name_test.go index 9019020..712bc49 100644 --- a/dhcpv4/bsdp/bsdp_option_machine_name_test.go +++ b/dhcpv4/bsdp/bsdp_option_machine_name_test.go @@ -15,25 +15,10 @@ func TestOptMachineNameInterfaceMethods(t *testing.T) { } func TestParseOptMachineName(t *testing.T) { - data := []byte{130, 7, 's', 'o', 'm', 'e', 'b', 'o', 'x'} + data := []byte{'s', 'o', 'm', 'e', 'b', 'o', 'x'} o, err := ParseOptMachineName(data) require.NoError(t, err) require.Equal(t, &OptMachineName{"somebox"}, o) - - // Short byte stream - data = []byte{130} - _, err = ParseOptMachineName(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptMachineName(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{130, 5, 1} - _, err = ParseOptMachineName(data) - require.Error(t, err, "should get error from bad length") } func TestOptMachineNameString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_message_type.go b/dhcpv4/bsdp/bsdp_option_message_type.go index c11b56b..8c3c3d4 100644 --- a/dhcpv4/bsdp/bsdp_option_message_type.go +++ b/dhcpv4/bsdp/bsdp_option_message_type.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// MessageType represents the different BSDP message types. +// // Implements the BSDP option message type. Can be one of LIST, SELECT, or // FAILED. - -// MessageType represents the different BSDP message types. type MessageType byte // BSDP Message types - e.g. LIST, SELECT, FAILED @@ -20,14 +21,14 @@ const ( ) func (m MessageType) String() string { - if s, ok := MessageTypeToString[m]; ok { + if s, ok := messageTypeToString[m]; ok { return s } return "Unknown" } -// MessageTypeToString maps each BSDP message type to a human-readable string. -var MessageTypeToString = map[MessageType]string{ +// messageTypeToString maps each BSDP message type to a human-readable string. +var messageTypeToString = map[MessageType]string{ MessageTypeList: "LIST", MessageTypeSelect: "SELECT", MessageTypeFailed: "FAILED", @@ -41,18 +42,8 @@ type OptMessageType struct { // ParseOptMessageType constructs an OptMessageType struct from a sequence of // bytes and returns it, or an error. func ParseOptMessageType(data []byte) (*OptMessageType, error) { - if len(data) < 3 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionMessageType { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionMessageType, code) - } - length := int(data[1]) - if length != 1 { - return nil, fmt.Errorf("expected length 1, got %d instead", length) - } - return &OptMessageType{Type: MessageType(data[2])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptMessageType{Type: MessageType(buf.Read8())}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_message_type_test.go b/dhcpv4/bsdp/bsdp_option_message_type_test.go index 9b7564f..41652be 100644 --- a/dhcpv4/bsdp/bsdp_option_message_type_test.go +++ b/dhcpv4/bsdp/bsdp_option_message_type_test.go @@ -14,25 +14,10 @@ func TestOptMessageTypeInterfaceMethods(t *testing.T) { } func TestParseOptMessageType(t *testing.T) { - data := []byte{1, 1, 1} // DISCOVER + data := []byte{1} // DISCOVER o, err := ParseOptMessageType(data) require.NoError(t, err) require.Equal(t, &OptMessageType{MessageTypeList}, o) - - // Short byte stream - data = []byte{1, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{1, 5, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from bad length") } func TestOptMessageTypeString(t *testing.T) { diff --git a/dhcpv4/bsdp/bsdp_option_reply_port.go b/dhcpv4/bsdp/bsdp_option_reply_port.go index f1cc49f..da5e9c4 100644 --- a/dhcpv4/bsdp/bsdp_option_reply_port.go +++ b/dhcpv4/bsdp/bsdp_option_reply_port.go @@ -5,14 +5,15 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// OptReplyPort represents a BSDP protocol version. +// // Implements the BSDP option reply port. This is used when BSDP responses // should be sent to a reply port other than the DHCP default. The macOS GUI // "Startup Disk Select" sends this option since it's operating in an // unprivileged context. - -// OptReplyPort represents a BSDP protocol version. type OptReplyPort struct { Port uint16 } @@ -20,19 +21,8 @@ type OptReplyPort struct { // ParseOptReplyPort constructs an OptReplyPort struct from a sequence of // bytes and returns it, or an error. func ParseOptReplyPort(data []byte) (*OptReplyPort, error) { - if len(data) < 4 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionReplyPort { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionReplyPort, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("expected length 2, got %d instead", length) - } - port := binary.BigEndian.Uint16(data[2:4]) - return &OptReplyPort{port}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptReplyPort{buf.Read16()}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_reply_port_test.go b/dhcpv4/bsdp/bsdp_option_reply_port_test.go index c9906ff..719bbc8 100644 --- a/dhcpv4/bsdp/bsdp_option_reply_port_test.go +++ b/dhcpv4/bsdp/bsdp_option_reply_port_test.go @@ -14,23 +14,18 @@ func TestOptReplyPortInterfaceMethods(t *testing.T) { } func TestParseOptReplyPort(t *testing.T) { - data := []byte{byte(OptionReplyPort), 2, 0, 1} + data := []byte{0, 1} o, err := ParseOptReplyPort(data) require.NoError(t, err) require.Equal(t, &OptReplyPort{1}, o) // Short byte stream - data = []byte{byte(OptionReplyPort), 2} + data = []byte{} _, err = ParseOptReplyPort(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 0} - _, err = ParseOptReplyPort(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionReplyPort), 4, 1, 0} + data = []byte{1} _, err = ParseOptReplyPort(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go index 5b00ded..52b6eab 100644 --- a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go +++ b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go @@ -4,12 +4,13 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) +// OptSelectedBootImageID contains the selected boot image ID. +// // Implements the BSDP option selected boot image ID, which tells the server // which boot image has been selected by the client. - -// OptSelectedBootImageID contains the selected boot image ID. type OptSelectedBootImageID struct { ID BootImageID } @@ -17,22 +18,12 @@ type OptSelectedBootImageID struct { // ParseOptSelectedBootImageID constructs an OptSelectedBootImageID struct from a sequence of // bytes and returns it, or an error. func ParseOptSelectedBootImageID(data []byte) (*OptSelectedBootImageID, error) { - if len(data) < 6 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionSelectedBootImageID { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionSelectedBootImageID, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("expected length 4, got %d instead", length) - } - id, err := BootImageIDFromBytes(data[2:6]) - if err != nil { + var o OptSelectedBootImageID + buf := uio.NewBigEndianBuffer(data) + if err := o.ID.Unmarshal(buf); err != nil { return nil, err } - return &OptSelectedBootImageID{*id}, nil + return &o, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go index 9529e41..a55fd9f 100644 --- a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go +++ b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go @@ -17,24 +17,18 @@ func TestOptSelectedBootImageIDInterfaceMethods(t *testing.T) { func TestParseOptSelectedBootImageID(t *testing.T) { b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - bootImageBytes := b.ToBytes() - data := append([]byte{byte(OptionSelectedBootImageID), 4}, bootImageBytes...) + data := b.ToBytes() o, err := ParseOptSelectedBootImageID(data) require.NoError(t, err) require.Equal(t, &OptSelectedBootImageID{b}, o) // Short byte stream - data = []byte{byte(OptionSelectedBootImageID), 4} + data = []byte{} _, err = ParseOptSelectedBootImageID(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 0, 0, 0} - _, err = ParseOptSelectedBootImageID(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionSelectedBootImageID), 5, 1, 0, 0, 0, 0} + data = []byte{1, 0, 0, 0, 0} _, err = ParseOptSelectedBootImageID(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/bsdp/bsdp_option_server_identifier.go b/dhcpv4/bsdp/bsdp_option_server_identifier.go index 252a0aa..26ec37a 100644 --- a/dhcpv4/bsdp/bsdp_option_server_identifier.go +++ b/dhcpv4/bsdp/bsdp_option_server_identifier.go @@ -5,6 +5,7 @@ import ( "net" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) // OptServerIdentifier represents an option encapsulating the server identifier. @@ -15,21 +16,8 @@ type OptServerIdentifier struct { // ParseOptServerIdentifier returns a new OptServerIdentifier from a byte // stream, or error if any. func ParseOptServerIdentifier(data []byte) (*OptServerIdentifier, error) { - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionServerIdentifier { - return nil, fmt.Errorf("expected code %v, got %v", OptionServerIdentifier, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexpected length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, dhcpv4.ErrShortByteStream - } - return &OptServerIdentifier{ServerID: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptServerIdentifier{ServerID: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_server_identifier_test.go b/dhcpv4/bsdp/bsdp_option_server_identifier_test.go index 5267caa..d832c40 100644 --- a/dhcpv4/bsdp/bsdp_option_server_identifier_test.go +++ b/dhcpv4/bsdp/bsdp_option_server_identifier_test.go @@ -26,15 +26,9 @@ func TestParseOptServerIdentifier(t *testing.T) { require.Error(t, err, "empty byte stream") o, err = ParseOptServerIdentifier([]byte{3, 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptServerIdentifier([]byte{3, 3, 192, 168, 0, 1}) require.Error(t, err, "wrong IP length") - o, err = ParseOptServerIdentifier([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptServerIdentifier([]byte{3, 4, 192, 168, 0, 1}) + o, err = ParseOptServerIdentifier([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.ServerID) } diff --git a/dhcpv4/bsdp/bsdp_option_server_priority.go b/dhcpv4/bsdp/bsdp_option_server_priority.go index 1952b7e..66bfa44 100644 --- a/dhcpv4/bsdp/bsdp_option_server_priority.go +++ b/dhcpv4/bsdp/bsdp_option_server_priority.go @@ -5,31 +5,19 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) -// This option implements the server identifier option -// https://tools.ietf.org/html/rfc2132 - // OptServerPriority represents an option encapsulating the server priority. type OptServerPriority struct { - Priority int + Priority uint16 } // ParseOptServerPriority returns a new OptServerPriority from a byte stream, or // error if any. func ParseOptServerPriority(data []byte) (*OptServerPriority, error) { - if len(data) < 4 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionServerPriority { - return nil, fmt.Errorf("expected code %v, got %v", OptionServerPriority, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("unexpected length: expected 2, got %v", length) - } - return &OptServerPriority{Priority: int(binary.BigEndian.Uint16(data[2:4]))}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptServerPriority{Priority: buf.Read16()}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_server_priority_test.go b/dhcpv4/bsdp/bsdp_option_server_priority_test.go index d12ad55..cbcef1d 100644 --- a/dhcpv4/bsdp/bsdp_option_server_priority_test.go +++ b/dhcpv4/bsdp/bsdp_option_server_priority_test.go @@ -22,16 +22,10 @@ func TestParseOptServerPriority(t *testing.T) { o, err = ParseOptServerPriority([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptServerPriority([]byte{4, 2, 1}) + o, err = ParseOptServerPriority([]byte{1}) require.Error(t, err, "short byte stream") - o, err = ParseOptServerPriority([]byte{4, 3, 1, 1}) - require.Error(t, err, "wrong priority length") - - o, err = ParseOptServerPriority([]byte{53, 2, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptServerPriority([]byte{4, 2, 0, 100}) + o, err = ParseOptServerPriority([]byte{0, 100}) require.NoError(t, err) - require.Equal(t, 100, o.Priority) + require.Equal(t, uint16(100), o.Priority) } diff --git a/dhcpv4/bsdp/bsdp_option_version.go b/dhcpv4/bsdp/bsdp_option_version.go index 8431a94..38158d7 100644 --- a/dhcpv4/bsdp/bsdp_option_version.go +++ b/dhcpv4/bsdp/bsdp_option_version.go @@ -4,10 +4,9 @@ import ( "fmt" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" ) -// Implements the BSDP option version. Can be one of 1.0 or 1.1 - // Specific versions. var ( Version1_0 = []byte{1, 0} @@ -15,6 +14,8 @@ var ( ) // OptVersion represents a BSDP protocol version. +// +// Implements the BSDP option version. Can be one of 1.0 or 1.1 type OptVersion struct { Version []byte } @@ -22,18 +23,8 @@ type OptVersion struct { // ParseOptVersion constructs an OptVersion struct from a sequence of // bytes and returns it, or an error. func ParseOptVersion(data []byte) (*OptVersion, error) { - if len(data) < 4 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != OptionVersion { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionVersion, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("expected length 2, got %d instead", length) - } - return &OptVersion{data[2:4]}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptVersion{buf.CopyN(2)}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/bsdp/bsdp_option_version_test.go b/dhcpv4/bsdp/bsdp_option_version_test.go index c6a6afc..d28f243 100644 --- a/dhcpv4/bsdp/bsdp_option_version_test.go +++ b/dhcpv4/bsdp/bsdp_option_version_test.go @@ -14,25 +14,15 @@ func TestOptVersionInterfaceMethods(t *testing.T) { } func TestParseOptVersion(t *testing.T) { - data := []byte{2, 2, 1, 1} + data := []byte{1, 1} o, err := ParseOptVersion(data) require.NoError(t, err) require.Equal(t, &OptVersion{Version1_1}, o) // Short byte stream - data = []byte{2, 2} + data = []byte{2} _, err = ParseOptVersion(data) require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 0} - _, err = ParseOptVersion(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{2, 4, 1, 0} - _, err = ParseOptVersion(data) - require.Error(t, err, "should get error from bad length") } func TestOptVersionString(t *testing.T) { diff --git a/dhcpv4/bsdp/option_vendor_specific_information.go b/dhcpv4/bsdp/option_vendor_specific_information.go index e735b57..1bd41a7 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information.go +++ b/dhcpv4/bsdp/option_vendor_specific_information.go @@ -1,8 +1,6 @@ package bsdp import ( - "errors" - "fmt" "strings" "github.com/insomniacslk/dhcp/dhcpv4" @@ -11,20 +9,17 @@ import ( // OptVendorSpecificInformation encapsulates the BSDP-specific options used for // the protocol. type OptVendorSpecificInformation struct { - Options []dhcpv4.Option + Options dhcpv4.Options } // parseOption is similar to dhcpv4.ParseOption, except that it switches based // on the BSDP specific options. -func parseOption(data []byte) (dhcpv4.Option, error) { - if len(data) == 0 { - return nil, dhcpv4.ErrZeroLengthByteStream - } +func parseOption(code dhcpv4.OptionCode, data []byte) (dhcpv4.Option, error) { var ( opt dhcpv4.Option err error ) - switch dhcpv4.OptionCode(data[0]) { + switch code { case OptionBootImageList: opt, err = ParseOptBootImageList(data) case OptionDefaultBootImageID: @@ -44,7 +39,7 @@ func parseOption(data []byte) (dhcpv4.Option, error) { case OptionVersion: opt, err = ParseOptVersion(data) default: - opt, err = ParseOptGeneric(data) + opt, err = ParseOptGeneric(code, data) } if err != nil { return nil, err @@ -55,39 +50,10 @@ func parseOption(data []byte) (dhcpv4.Option, error) { // ParseOptVendorSpecificInformation constructs an OptVendorSpecificInformation struct from a sequence of // bytes and returns it, or an error. func ParseOptVendorSpecificInformation(data []byte) (*OptVendorSpecificInformation, error) { - // Should at least have code + length - if len(data) < 2 { - return nil, dhcpv4.ErrShortByteStream - } - code := dhcpv4.OptionCode(data[0]) - if code != dhcpv4.OptionVendorSpecificInformation { - return nil, fmt.Errorf("expected option %v, got %v instead", dhcpv4.OptionVendorSpecificInformation, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("expected length 2, got %d instead", length) - } - - options := make([]dhcpv4.Option, 0, 10) - idx := 2 - for { - if idx == len(data) { - break - } - // This should never happen. - if idx > len(data) { - return nil, errors.New("read past the end of options") - } - opt, err := parseOption(data[idx:]) - if err != nil { - return nil, err - } - options = append(options, opt) - - // Account for code + length bytes - idx += 2 + opt.Length() + options, err := dhcpv4.OptionsFromBytesWithParser(data, parseOption, false /* don't check for OptionEnd tag */) + if err != nil { + return nil, err } - return &OptVendorSpecificInformation{options}, nil } @@ -133,20 +99,10 @@ func (o *OptVendorSpecificInformation) Length() int { // GetOption returns all suboptions that match the given OptionCode code. func (o *OptVendorSpecificInformation) GetOption(code dhcpv4.OptionCode) []dhcpv4.Option { - var opts []dhcpv4.Option - for _, opt := range o.Options { - if opt.Code() == code { - opts = append(opts, opt) - } - } - return opts + return o.Options.GetOption(code) } // GetOneOption returns the first suboption that matches the OptionCode code. func (o *OptVendorSpecificInformation) GetOneOption(code dhcpv4.OptionCode) dhcpv4.Option { - opts := o.GetOption(code) - if len(opts) == 0 { - return nil - } - return opts[0] + return o.Options.GetOneOption(code) } diff --git a/dhcpv4/bsdp/option_vendor_specific_information_test.go b/dhcpv4/bsdp/option_vendor_specific_information_test.go index 8a4368f..8aa2bdf 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information_test.go +++ b/dhcpv4/bsdp/option_vendor_specific_information_test.go @@ -34,19 +34,11 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { o *OptVendorSpecificInformation err error ) - o, err = ParseOptVendorSpecificInformation([]byte{}) - require.Error(t, err, "empty byte stream") - o, err = ParseOptVendorSpecificInformation([]byte{1, 2}) require.Error(t, err, "short byte stream") - o, err = ParseOptVendorSpecificInformation([]byte{53, 2, 1, 1}) - require.Error(t, err, "wrong option code") - // Good byte stream data := []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, 1, // Version option } @@ -59,13 +51,13 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { }, } require.Equal(t, 2, len(o.Options), "number of parsed suboptions") - require.Equal(t, expected.Options[0].Code(), o.Options[0].Code()) - require.Equal(t, expected.Options[1].Code(), o.Options[1].Code()) + typ := o.GetOneOption(OptionMessageType) + version := o.GetOneOption(OptionVersion) + require.Equal(t, expected.Options[0].Code(), typ.Code()) + require.Equal(t, expected.Options[1].Code(), version.Code()) // Short byte stream (length and data mismatch) data = []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, // Version option } @@ -74,8 +66,6 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { // Bad option data = []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, // Version option 5, 3, 1, 1, 1, // Reply port option @@ -85,8 +75,6 @@ func TestParseOptVendorSpecificInformation(t *testing.T) { // Boot images + default. data = []byte{ - 43, // code - 7, // length 1, 1, 1, // List option 2, 2, 1, 1, // Version option 5, 2, 1, 1, // Reply port option diff --git a/dhcpv4/bsdp/types.go b/dhcpv4/bsdp/types.go index ac4b05b..aa9a824 100644 --- a/dhcpv4/bsdp/types.go +++ b/dhcpv4/bsdp/types.go @@ -26,9 +26,9 @@ const ( OptionMachineName dhcpv4.OptionCode = 130 ) -// OptionCodeToString maps BSDP OptionCodes to human-readable strings +// optionCodeToString maps BSDP OptionCodes to human-readable strings // describing what they are. -var OptionCodeToString = map[dhcpv4.OptionCode]string{ +var optionCodeToString = map[dhcpv4.OptionCode]string{ OptionMessageType: "BSDP Message Type", OptionVersion: "BSDP Version", OptionServerIdentifier: "BSDP Server Identifier", diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index b24502f..2f832e8 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -46,7 +46,7 @@ type DHCPv4 struct { ClientHWAddr net.HardwareAddr ServerHostName string BootFileName string - Options []Option + Options Options } // Modifier defines the signature for functions that can modify DHCPv4 @@ -363,9 +363,32 @@ func (d *DHCPv4) SetUnicast() { // GetOption will attempt to get all options that match a DHCPv4 option // from its OptionCode. If the option was not found it will return an // empty list. +// +// According to RFC 3396, options that are specified more than once are +// concatenated, and hence this should always just return one option. func (d *DHCPv4) GetOption(code OptionCode) []Option { + return d.Options.GetOption(code) +} + +// GetOneOption will attempt to get an option that match a Option code. +// If there are multiple options with the same OptionCode it will only return +// the first one found. If no matching option is found nil will be returned. +func (d *DHCPv4) GetOneOption(code OptionCode) Option { + return d.Options.GetOneOption(code) +} + +// Options is a collection of options. +type Options []Option + +// GetOption will attempt to get all options that match a DHCPv4 option +// from its OptionCode. If the option was not found it will return an +// empty list. +// +// According to RFC 3396, options that are specified more than once are +// concatenated, and hence this should always just return one option. +func (o Options) GetOption(code OptionCode) []Option { opts := []Option{} - for _, opt := range d.Options { + for _, opt := range o { if opt.Code() == code { opts = append(opts, opt) } @@ -376,8 +399,8 @@ func (d *DHCPv4) GetOption(code OptionCode) []Option { // GetOneOption will attempt to get an option that match a Option code. // If there are multiple options with the same OptionCode it will only return // the first one found. If no matching option is found nil will be returned. -func (d *DHCPv4) GetOneOption(code OptionCode) Option { - for _, opt := range d.Options { +func (o Options) GetOneOption(code OptionCode) Option { + for _, opt := range o { if opt.Code() == code { return opt } @@ -385,21 +408,6 @@ func (d *DHCPv4) GetOneOption(code OptionCode) Option { return nil } -// StrippedOptions works like Options, but it does not return anything after the -// End option. -func (d *DHCPv4) StrippedOptions() []Option { - // differently from Options() this function strips away anything coming - // after the End option (normally just Pad options). - strippedOptions := []Option{} - for _, opt := range d.Options { - strippedOptions = append(strippedOptions, opt) - if opt.Code() == OptionEnd { - break - } - } - return strippedOptions -} - // AddOption appends an option to the existing ones. If the last option is an // OptionEnd, it will be inserted before that. It does not deal with End // options that appead before the end, like in malformed packets. diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index f6f9d7c..9be2a1a 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -44,6 +44,7 @@ func TestFromBytes(t *testing.T) { 0, 0, 0, 0, // gateway IP address 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // client MAC address + padding } + // server host name expectedHostname := []byte{} for i := 0; i < 64; i++ { @@ -57,7 +58,7 @@ func TestFromBytes(t *testing.T) { } data = append(data, expectedBootfilename...) // magic cookie, then no options - data = append(data, []byte{99, 130, 83, 99}...) + data = append(data, magicCookie[:]...) d, err := FromBytes(data) require.NoError(t, err) @@ -179,11 +180,11 @@ func TestGetOption(t *testing.T) { } hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")} - bootFileOpt1 := &OptBootfileName{[]byte("boot.img")} - bootFileOpt2 := &OptBootfileName{[]byte("boot2.img")} + bootFileOpt1 := &OptBootfileName{"boot.img"} + bootFileOpt2 := &OptBootfileName{"boot2.img"} d.AddOption(hostnameOpt) - d.AddOption(&OptBootfileName{[]byte("boot.img")}) - d.AddOption(&OptBootfileName{[]byte("boot2.img")}) + d.AddOption(&OptBootfileName{"boot.img"}) + d.AddOption(&OptBootfileName{"boot2.img"}) require.Equal(t, d.GetOption(OptionHostName), []Option{hostnameOpt}) require.Equal(t, d.GetOption(OptionBootfileName), []Option{bootFileOpt1, bootFileOpt2}) @@ -227,32 +228,6 @@ func TestUpdateOption(t *testing.T) { require.Equal(t, OptionEnd, d.Options[1].Code()) } -func TestStrippedOptions(t *testing.T) { - // Normal set of options that terminate with OptionEnd. - d, err := New() - require.NoError(t, err) - opts := []Option{ - &OptBootfileName{[]byte("boot.img")}, - &OptClassIdentifier{"something"}, - &OptionGeneric{OptionCode: OptionEnd}, - } - d.Options = opts - stripped := d.StrippedOptions() - require.Equal(t, len(opts), len(stripped)) - for i := range stripped { - require.Equal(t, opts[i], stripped[i]) - } - - // Set of options with additional options after OptionEnd - opts = append(opts, &OptMaximumDHCPMessageSize{uint16(1234)}) - d.Options = opts - stripped = d.StrippedOptions() - require.Equal(t, len(opts)-1, len(stripped)) - for i := range stripped { - require.Equal(t, opts[i], stripped[i]) - } -} - func TestDHCPv4NewRequestFromOffer(t *testing.T) { offer, err := New() require.NoError(t, err) diff --git a/dhcpv4/option_archtype.go b/dhcpv4/option_archtype.go index 92a3769..f5882bc 100644 --- a/dhcpv4/option_archtype.go +++ b/dhcpv4/option_archtype.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/insomniacslk/dhcp/iana" + "github.com/u-root/u-root/pkg/uio" ) // OptClientArchType represents an option encapsulating the Client System @@ -35,7 +36,7 @@ func (o *OptClientArchType) ToBytes() []byte { // Length returns the length of the data portion (excluding option code an byte // length). func (o *OptClientArchType) Length() int { - return 2*len(o.ArchTypes) + return 2 * len(o.ArchTypes) } // String returns a human-readable string. @@ -53,24 +54,14 @@ func (o *OptClientArchType) String() string { // ParseOptClientArchType returns a new OptClientArchType from a byte stream, // or error if any. func ParseOptClientArchType(data []byte) (*OptClientArchType, error) { - if len(data) < 2 { - return nil, ErrShortByteStream + buf := uio.NewBigEndianBuffer(data) + if buf.Len() == 0 { + return nil, fmt.Errorf("must have at least one archtype if option is present") } - code := OptionCode(data[0]) - if code != OptionClientSystemArchitectureType { - return nil, fmt.Errorf("expected code %v, got %v", OptionClientSystemArchitectureType, code) - } - length := int(data[1]) - if length == 0 || length%2 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 2 larger than 2, got %v", length) - } - if len(data) < 2+length { - return nil, ErrShortByteStream - } - archTypes := make([]iana.ArchType, 0, length%2) - for idx := 0; idx < length; idx += 2 { - b := data[2+idx : 2+idx+2] - archTypes = append(archTypes, iana.ArchType(binary.BigEndian.Uint16(b))) + + archTypes := make([]iana.ArchType, 0, buf.Len()/2) + for buf.Has(2) { + archTypes = append(archTypes, iana.ArchType(buf.Read16())) } - return &OptClientArchType{ArchTypes: archTypes}, nil + return &OptClientArchType{ArchTypes: archTypes}, buf.FinError() } diff --git a/dhcpv4/option_archtype_test.go b/dhcpv4/option_archtype_test.go index d803328..482ebb1 100644 --- a/dhcpv4/option_archtype_test.go +++ b/dhcpv4/option_archtype_test.go @@ -13,7 +13,7 @@ func TestParseOptClientArchType(t *testing.T) { 2, // Length 0, 6, // EFI_IA32 } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ArchTypes[0], iana.EFI_IA32) } @@ -25,7 +25,7 @@ func TestParseOptClientArchTypeMultiple(t *testing.T) { 0, 6, // EFI_IA32 0, 2, // EFI_ITANIUM } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ArchTypes[0], iana.EFI_IA32) require.Equal(t, opt.ArchTypes[1], iana.EFI_ITANIUM) @@ -43,7 +43,7 @@ func TestOptClientArchTypeParseAndToBytes(t *testing.T) { 2, // Length 0, 8, // EFI_XSCALE } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ToBytes(), data) } @@ -55,7 +55,7 @@ func TestOptClientArchTypeParseAndToBytesMultiple(t *testing.T) { 0, 8, // EFI_XSCALE 0, 6, // EFI_IA32 } - opt, err := ParseOptClientArchType(data) + opt, err := ParseOptClientArchType(data[2:]) require.NoError(t, err) require.Equal(t, opt.ToBytes(), data) } diff --git a/dhcpv4/option_bootfile_name.go b/dhcpv4/option_bootfile_name.go index ca9317b..e06e4eb 100644 --- a/dhcpv4/option_bootfile_name.go +++ b/dhcpv4/option_bootfile_name.go @@ -9,7 +9,7 @@ import ( // OptBootfileName implements the BootFile Name option type OptBootfileName struct { - BootfileName []byte + BootfileName string } // Code returns the option code @@ -19,7 +19,7 @@ func (op *OptBootfileName) Code() OptionCode { // ToBytes serializes the option and returns it as a sequence of bytes func (op *OptBootfileName) ToBytes() []byte { - return append([]byte{byte(op.Code()), byte(op.Length())}, op.BootfileName...) + return append([]byte{byte(op.Code()), byte(op.Length())}, []byte(op.BootfileName)...) } // Length returns the option length in bytes @@ -34,21 +34,5 @@ func (op *OptBootfileName) String() string { // ParseOptBootfileName returns a new OptBootfile from a byte stream or error if any func ParseOptBootfileName(data []byte) (*OptBootfileName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionBootfileName { - return nil, fmt.Errorf("ParseOptBootfileName: invalid code: %v; want %v", code, OptionBootfileName) - } - length := int(data[1]) - if length < 1 { - return nil, fmt.Errorf("Bootfile name has invalid length of %d", length) - } - bootFileNameData := data[2:] - if len(bootFileNameData) < length { - return nil, fmt.Errorf("ParseOptBootfileName: short data: %d bytes; want %d", - len(bootFileNameData), length) - } - return &OptBootfileName{BootfileName: bootFileNameData[:length]}, nil + return &OptBootfileName{BootfileName: string(data)}, nil } diff --git a/dhcpv4/option_bootfile_name_test.go b/dhcpv4/option_bootfile_name_test.go index 0c7c200..2671ac5 100644 --- a/dhcpv4/option_bootfile_name_test.go +++ b/dhcpv4/option_bootfile_name_test.go @@ -13,7 +13,7 @@ func TestOptBootfileNameCode(t *testing.T) { func TestOptBootfileNameToBytes(t *testing.T) { opt := OptBootfileName{ - BootfileName: []byte("linuxboot"), + BootfileName: "linuxboot", } data := opt.ToBytes() expected := []byte{ @@ -26,40 +26,15 @@ func TestOptBootfileNameToBytes(t *testing.T) { func TestParseOptBootfileName(t *testing.T) { expected := []byte{ - 67, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptBootfileName(expected) require.NoError(t, err) require.Equal(t, 9, opt.Length()) - require.Equal(t, "linuxboot", string(opt.BootfileName)) -} - -func TestParseOptBootfileNameZeroLength(t *testing.T) { - expected := []byte{ - 67, 0, - } - _, err := ParseOptBootfileName(expected) - require.Error(t, err) -} - -func TestParseOptBootfileNameInvalidLength(t *testing.T) { - expected := []byte{ - 67, 9, 'l', 'i', 'n', 'u', 'x', 'b', - } - _, err := ParseOptBootfileName(expected) - require.Error(t, err) -} - -func TestParseOptBootfileNameShortLength(t *testing.T) { - expected := []byte{ - 67, 4, 'l', 'i', 'n', 'u', 'x', - } - opt, err := ParseOptBootfileName(expected) - require.NoError(t, err) - require.Equal(t, []byte("linu"), opt.BootfileName) + require.Equal(t, "linuxboot", opt.BootfileName) } func TestOptBootfileNameString(t *testing.T) { - o := OptBootfileName{BootfileName: []byte("testy test")} + o := OptBootfileName{BootfileName: "testy test"} require.Equal(t, "Bootfile Name -> testy test", o.String()) } diff --git a/dhcpv4/option_broadcast_address.go b/dhcpv4/option_broadcast_address.go index ffa57e3..fdce946 100644 --- a/dhcpv4/option_broadcast_address.go +++ b/dhcpv4/option_broadcast_address.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the server identifier option @@ -16,21 +18,8 @@ type OptBroadcastAddress struct { // ParseOptBroadcastAddress returns a new OptBroadcastAddress from a byte // stream, or error if any. func ParseOptBroadcastAddress(data []byte) (*OptBroadcastAddress, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionBroadcastAddress { - return nil, fmt.Errorf("expected code %v, got %v", OptionBroadcastAddress, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptBroadcastAddress{BroadcastAddress: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptBroadcastAddress{BroadcastAddress: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_broadcast_address_test.go b/dhcpv4/option_broadcast_address_test.go index 3572dc0..1feb6cc 100644 --- a/dhcpv4/option_broadcast_address_test.go +++ b/dhcpv4/option_broadcast_address_test.go @@ -29,16 +29,10 @@ func TestParseOptBroadcastAddress(t *testing.T) { o, err = ParseOptBroadcastAddress([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptBroadcastAddress([]byte{byte(OptionBroadcastAddress), 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptBroadcastAddress([]byte{byte(OptionBroadcastAddress), 3, 192, 168, 0, 1}) + o, err = ParseOptBroadcastAddress([]byte{192, 168, 0}) require.Error(t, err, "wrong IP length") - o, err = ParseOptBroadcastAddress([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptBroadcastAddress([]byte{byte(OptionBroadcastAddress), 4, 192, 168, 0, 1}) + o, err = ParseOptBroadcastAddress([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.BroadcastAddress) } diff --git a/dhcpv4/option_class_identifier.go b/dhcpv4/option_class_identifier.go index ae5dba2..1a49b87 100644 --- a/dhcpv4/option_class_identifier.go +++ b/dhcpv4/option_class_identifier.go @@ -15,19 +15,7 @@ type OptClassIdentifier struct { // ParseOptClassIdentifier constructs an OptClassIdentifier struct from a sequence of // bytes and returns it, or an error. func ParseOptClassIdentifier(data []byte) (*OptClassIdentifier, error) { - // Should at least have code and length - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionClassIdentifier { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionClassIdentifier, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptClassIdentifier{Identifier: string(data[2 : 2+length])}, nil + return &OptClassIdentifier{Identifier: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_class_identifier_test.go b/dhcpv4/option_class_identifier_test.go index 786ecfb..289eafe 100644 --- a/dhcpv4/option_class_identifier_test.go +++ b/dhcpv4/option_class_identifier_test.go @@ -14,25 +14,10 @@ func TestOptClassIdentifierInterfaceMethods(t *testing.T) { } func TestParseOptClassIdentifier(t *testing.T) { - data := []byte{byte(OptionClassIdentifier), 4, 't', 'e', 's', 't'} // DISCOVER + data := []byte{'t', 'e', 's', 't'} o, err := ParseOptClassIdentifier(data) require.NoError(t, err) require.Equal(t, &OptClassIdentifier{Identifier: "test"}, o) - - // Short byte stream - data = []byte{byte(OptionClassIdentifier)} - _, err = ParseOptClassIdentifier(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptClassIdentifier(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionClassIdentifier), 6, 1, 1, 1} - _, err = ParseOptClassIdentifier(data) - require.Error(t, err, "should get error from bad length") } func TestOptClassIdentifierString(t *testing.T) { diff --git a/dhcpv4/option_domain_name.go b/dhcpv4/option_domain_name.go index e876b05..673b2a6 100644 --- a/dhcpv4/option_domain_name.go +++ b/dhcpv4/option_domain_name.go @@ -13,18 +13,7 @@ type OptDomainName struct { // ParseOptDomainName returns a new OptDomainName from a byte // stream, or error if any. func ParseOptDomainName(data []byte) (*OptDomainName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionDomainName { - return nil, fmt.Errorf("expected code %v, got %v", OptionDomainName, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptDomainName{DomainName: string(data[2 : 2+length])}, nil + return &OptDomainName{DomainName: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_domain_name_server.go b/dhcpv4/option_domain_name_server.go index 470eaa0..8633cc4 100644 --- a/dhcpv4/option_domain_name_server.go +++ b/dhcpv4/option_domain_name_server.go @@ -17,26 +17,11 @@ type OptDomainNameServer struct { // ParseOptDomainNameServer returns a new OptDomainNameServer from a byte // stream, or error if any. func ParseOptDomainNameServer(data []byte) (*OptDomainNameServer, error) { - if len(data) < 2 { - return nil, ErrShortByteStream + ips, err := ParseIPs(data) + if err != nil { + return nil, err } - code := OptionCode(data[0]) - if code != OptionDomainNameServer { - return nil, fmt.Errorf("expected code %v, got %v", OptionDomainNameServer, code) - } - length := int(data[1]) - if length == 0 || length%4 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 4 larger than 4, got %v", length) - } - if len(data) < 2+length { - return nil, ErrShortByteStream - } - nameservers := make([]net.IP, 0, length%4) - for idx := 0; idx < length; idx += 4 { - b := data[2+idx : 2+idx+4] - nameservers = append(nameservers, net.IPv4(b[0], b[1], b[2], b[3])) - } - return &OptDomainNameServer{NameServers: nameservers}, nil + return &OptDomainNameServer{NameServers: ips}, nil } // Code returns the option code. diff --git a/dhcpv4/option_domain_name_server_test.go b/dhcpv4/option_domain_name_server_test.go index c801cb6..546d399 100644 --- a/dhcpv4/option_domain_name_server_test.go +++ b/dhcpv4/option_domain_name_server_test.go @@ -25,37 +25,23 @@ func TestParseOptDomainNameServer(t *testing.T) { 192, 168, 0, 10, // DNS #1 192, 168, 0, 20, // DNS #2 } - o, err := ParseOptDomainNameServer(data) + o, err := ParseOptDomainNameServer(data[2:]) require.NoError(t, err) servers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), + net.IP{192, 168, 0, 10}, + net.IP{192, 168, 0, 20}, } require.Equal(t, &OptDomainNameServer{NameServers: servers}, o) - // Short byte stream - data = []byte{byte(OptionDomainNameServer)} - _, err = ParseOptDomainNameServer(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptDomainNameServer(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionDomainNameServer), 6, 1, 1, 1} + data = []byte{1, 1, 1} _, err = ParseOptDomainNameServer(data) require.Error(t, err, "should get error from bad length") } func TestParseOptDomainNameServerNoServers(t *testing.T) { // RFC2132 requires that at least one DNS server IP is specified - data := []byte{ - byte(OptionDomainNameServer), - 0, // Length - } - _, err := ParseOptDomainNameServer(data) + _, err := ParseOptDomainNameServer([]byte{}) require.Error(t, err) } diff --git a/dhcpv4/option_domain_name_test.go b/dhcpv4/option_domain_name_test.go index ab66e8b..e88d87e 100644 --- a/dhcpv4/option_domain_name_test.go +++ b/dhcpv4/option_domain_name_test.go @@ -14,25 +14,10 @@ func TestOptDomainNameInterfaceMethods(t *testing.T) { } func TestParseOptDomainName(t *testing.T) { - data := []byte{byte(OptionDomainName), 4, 't', 'e', 's', 't'} // DISCOVER + data := []byte{'t', 'e', 's', 't'} o, err := ParseOptDomainName(data) require.NoError(t, err) require.Equal(t, &OptDomainName{DomainName: "test"}, o) - - // Short byte stream - data = []byte{byte(OptionDomainName)} - _, err = ParseOptDomainName(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptDomainName(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionDomainName), 6, 1, 1, 1} - _, err = ParseOptDomainName(data) - require.Error(t, err, "should get error from bad length") } func TestOptDomainNameString(t *testing.T) { diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go index 9c24eea..5fafb6e 100644 --- a/dhcpv4/option_domain_search.go +++ b/dhcpv4/option_domain_search.go @@ -43,18 +43,7 @@ func (op *OptDomainSearch) String() string { // ParseOptDomainSearch returns a new OptDomainSearch from a byte stream, or // error if any. func ParseOptDomainSearch(data []byte) (*OptDomainSearch, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionDNSDomainSearchList { - return nil, fmt.Errorf("expected code %v, got %v", OptionDNSDomainSearchList, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - labels, err := rfc1035label.FromBytes(data[2 : length+2]) + labels, err := rfc1035label.FromBytes(data) if err != nil { return nil, err } diff --git a/dhcpv4/option_domain_search_test.go b/dhcpv4/option_domain_search_test.go index 590ccd0..0e5f8e9 100644 --- a/dhcpv4/option_domain_search_test.go +++ b/dhcpv4/option_domain_search_test.go @@ -14,7 +14,7 @@ func TestParseOptDomainSearch(t *testing.T) { 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } - opt, err := ParseOptDomainSearch(data) + opt, err := ParseOptDomainSearch(data[2:]) require.NoError(t, err) require.Equal(t, 2, len(opt.DomainSearch.Labels)) require.Equal(t, data[2:], opt.DomainSearch.ToBytes()) diff --git a/dhcpv4/option_generic.go b/dhcpv4/option_generic.go index 4ff35f8..964a655 100644 --- a/dhcpv4/option_generic.go +++ b/dhcpv4/option_generic.go @@ -15,23 +15,11 @@ type OptionGeneric struct { // ParseOptionGeneric parses a bytestream and creates a new OptionGeneric from // it, or an error. -func ParseOptionGeneric(data []byte) (*OptionGeneric, error) { +func ParseOptionGeneric(code OptionCode, data []byte) (Option, error) { if len(data) == 0 { return nil, errors.New("invalid zero-length bytestream") } - var ( - length int - optionData []byte - ) - code := OptionCode(data[0]) - if code != OptionPad && code != OptionEnd { - length = int(data[1]) - if len(data) < length+2 { - return nil, fmt.Errorf("invalid data length: declared %v, actual %v", length, len(data)) - } - optionData = data[2 : length+2] - } - return &OptionGeneric{OptionCode: code, Data: optionData}, nil + return &OptionGeneric{OptionCode: code, Data: data}, nil } // Code returns the generic option code. diff --git a/dhcpv4/option_generic_test.go b/dhcpv4/option_generic_test.go index dbc0fc1..5c34903 100644 --- a/dhcpv4/option_generic_test.go +++ b/dhcpv4/option_generic_test.go @@ -8,7 +8,7 @@ import ( func TestParseOptionGeneric(t *testing.T) { // Empty bytestream produces error - _, err := ParseOptionGeneric([]byte{}) + _, err := ParseOptionGeneric(OptionHostName, []byte{}) require.Error(t, err, "error from empty bytestream") } @@ -20,14 +20,6 @@ func TestOptionGenericCode(t *testing.T) { require.Equal(t, OptionDHCPMessageType, o.Code()) } -func TestOptionGenericData(t *testing.T) { - o := OptionGeneric{ - OptionCode: OptionNameServer, - Data: []byte{192, 168, 0, 1}, - } - require.Equal(t, []byte{192, 168, 0, 1}, o.Data) -} - func TestOptionGenericToBytes(t *testing.T) { o := OptionGeneric{ OptionCode: OptionDHCPMessageType, diff --git a/dhcpv4/option_host_name.go b/dhcpv4/option_host_name.go index a922a2b..decc18b 100644 --- a/dhcpv4/option_host_name.go +++ b/dhcpv4/option_host_name.go @@ -13,18 +13,7 @@ type OptHostName struct { // ParseOptHostName returns a new OptHostName from a byte stream, or error if // any. func ParseOptHostName(data []byte) (*OptHostName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionHostName { - return nil, fmt.Errorf("expected code %v, got %v", OptionHostName, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptHostName{HostName: string(data[2 : 2+length])}, nil + return &OptHostName{HostName: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_host_name_test.go b/dhcpv4/option_host_name_test.go index 7f99100..f5e8548 100644 --- a/dhcpv4/option_host_name_test.go +++ b/dhcpv4/option_host_name_test.go @@ -14,25 +14,10 @@ func TestOptHostNameInterfaceMethods(t *testing.T) { } func TestParseOptHostName(t *testing.T) { - data := []byte{byte(OptionHostName), 4, 't', 'e', 's', 't'} + data := []byte{'t', 'e', 's', 't'} o, err := ParseOptHostName(data) require.NoError(t, err) require.Equal(t, &OptHostName{HostName: "test"}, o) - - // Short byte stream - data = []byte{byte(OptionHostName)} - _, err = ParseOptHostName(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptHostName(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionHostName), 6, 1, 1, 1} - _, err = ParseOptHostName(data) - require.Error(t, err, "should get error from bad length") } func TestOptHostNameString(t *testing.T) { diff --git a/dhcpv4/option_ip_address_lease_time.go b/dhcpv4/option_ip_address_lease_time.go index 7562c58..2f63fb7 100644 --- a/dhcpv4/option_ip_address_lease_time.go +++ b/dhcpv4/option_ip_address_lease_time.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the IP Address Lease Time option @@ -16,20 +18,9 @@ type OptIPAddressLeaseTime struct { // ParseOptIPAddressLeaseTime constructs an OptIPAddressLeaseTime struct from a // sequence of bytes and returns it, or an error. func ParseOptIPAddressLeaseTime(data []byte) (*OptIPAddressLeaseTime, error) { - // Should at least have code, length, and lease time. - if len(data) < 6 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionIPAddressLeaseTime { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionIPAddressLeaseTime, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("expected length 4, got %v instead", length) - } - leaseTime := binary.BigEndian.Uint32(data[2:6]) - return &OptIPAddressLeaseTime{LeaseTime: leaseTime}, nil + buf := uio.NewBigEndianBuffer(data) + leaseTime := buf.Read32() + return &OptIPAddressLeaseTime{LeaseTime: leaseTime}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_ip_address_lease_time_test.go b/dhcpv4/option_ip_address_lease_time_test.go index 7d507bf..dafa6e4 100644 --- a/dhcpv4/option_ip_address_lease_time_test.go +++ b/dhcpv4/option_ip_address_lease_time_test.go @@ -14,23 +14,18 @@ func TestOptIPAddressLeaseTimeInterfaceMethods(t *testing.T) { } func TestParseOptIPAddressLeaseTime(t *testing.T) { - data := []byte{51, 4, 0, 0, 168, 192} + data := []byte{0, 0, 168, 192} o, err := ParseOptIPAddressLeaseTime(data) require.NoError(t, err) require.Equal(t, &OptIPAddressLeaseTime{LeaseTime: 43200}, o) // Short byte stream - data = []byte{51, 4, 168, 192} + data = []byte{168, 192} _, err = ParseOptIPAddressLeaseTime(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 4, 0, 0, 168, 192} - _, err = ParseOptIPAddressLeaseTime(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{51, 5, 1, 1, 1, 1, 1} + data = []byte{1, 1, 1, 1, 1} _, err = ParseOptIPAddressLeaseTime(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/option_maximum_dhcp_message_size.go b/dhcpv4/option_maximum_dhcp_message_size.go index e5fedc6..26865fd 100644 --- a/dhcpv4/option_maximum_dhcp_message_size.go +++ b/dhcpv4/option_maximum_dhcp_message_size.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the Maximum DHCP Message size option @@ -16,20 +18,8 @@ type OptMaximumDHCPMessageSize struct { // ParseOptMaximumDHCPMessageSize constructs an OptMaximumDHCPMessageSize struct from a sequence of // bytes and returns it, or an error. func ParseOptMaximumDHCPMessageSize(data []byte) (*OptMaximumDHCPMessageSize, error) { - // Should at least have code, length, and message size. - if len(data) < 4 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionMaximumDHCPMessageSize { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionMaximumDHCPMessageSize, code) - } - length := int(data[1]) - if length != 2 { - return nil, fmt.Errorf("expected length 2, got %v instead", length) - } - msgSize := binary.BigEndian.Uint16(data[2:4]) - return &OptMaximumDHCPMessageSize{Size: msgSize}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptMaximumDHCPMessageSize{Size: buf.Read16()}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_maximum_dhcp_message_size_test.go b/dhcpv4/option_maximum_dhcp_message_size_test.go index f24b499..24ba49f 100644 --- a/dhcpv4/option_maximum_dhcp_message_size_test.go +++ b/dhcpv4/option_maximum_dhcp_message_size_test.go @@ -14,23 +14,18 @@ func TestOptMaximumDHCPMessageSizeInterfaceMethods(t *testing.T) { } func TestParseOptMaximumDHCPMessageSize(t *testing.T) { - data := []byte{57, 2, 5, 220} + data := []byte{5, 220} o, err := ParseOptMaximumDHCPMessageSize(data) require.NoError(t, err) require.Equal(t, &OptMaximumDHCPMessageSize{Size: 1500}, o) // Short byte stream - data = []byte{57, 2} + data = []byte{2} _, err = ParseOptMaximumDHCPMessageSize(data) require.Error(t, err, "should get error from short byte stream") - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptMaximumDHCPMessageSize(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{57, 3, 1, 1, 1} + data = []byte{1, 1, 1} _, err = ParseOptMaximumDHCPMessageSize(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/option_message_type.go b/dhcpv4/option_message_type.go index 903a57e..3e11a9a 100644 --- a/dhcpv4/option_message_type.go +++ b/dhcpv4/option_message_type.go @@ -2,6 +2,8 @@ package dhcpv4 import ( "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the message type option @@ -15,20 +17,8 @@ type OptMessageType struct { // ParseOptMessageType constructs an OptMessageType struct from a sequence of // bytes and returns it, or an error. func ParseOptMessageType(data []byte) (*OptMessageType, error) { - // Should at least have code, length, and message type. - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionDHCPMessageType { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionDHCPMessageType, code) - } - length := int(data[1]) - if length != 1 { - return nil, ErrShortByteStream - } - messageType := MessageType(data[2]) - return &OptMessageType{MessageType: messageType}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptMessageType{MessageType: MessageType(buf.Read8())}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_message_type_test.go b/dhcpv4/option_message_type_test.go index 59e6c17..091066e 100644 --- a/dhcpv4/option_message_type_test.go +++ b/dhcpv4/option_message_type_test.go @@ -20,23 +20,13 @@ func TestOptMessageTypeNew(t *testing.T) { } func TestParseOptMessageType(t *testing.T) { - data := []byte{53, 1, 1} // DISCOVER + data := []byte{1} // DISCOVER o, err := ParseOptMessageType(data) require.NoError(t, err) require.Equal(t, &OptMessageType{MessageType: MessageTypeDiscover}, o) - // Short byte stream - data = []byte{53, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 1, 1} - _, err = ParseOptMessageType(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{53, 5, 1} + data = []byte{1, 2} _, err = ParseOptMessageType(data) require.Error(t, err, "should get error from bad length") } diff --git a/dhcpv4/option_ntp_servers.go b/dhcpv4/option_ntp_servers.go index 39881d6..6d30920 100644 --- a/dhcpv4/option_ntp_servers.go +++ b/dhcpv4/option_ntp_servers.go @@ -15,26 +15,11 @@ type OptNTPServers struct { // ParseOptNTPServers returns a new OptNTPServers from a byte stream, or error if any. func ParseOptNTPServers(data []byte) (*OptNTPServers, error) { - if len(data) < 2 { - return nil, ErrShortByteStream + ips, err := ParseIPs(data) + if err != nil { + return nil, err } - code := OptionCode(data[0]) - if code != OptionNTPServers { - return nil, fmt.Errorf("expected code %v, got %v", OptionNTPServers, code) - } - length := int(data[1]) - if length == 0 || length%4 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 4 larger than 4, got %v", length) - } - if len(data) < 2+length { - return nil, ErrShortByteStream - } - ntpServers := make([]net.IP, 0, length%4) - for idx := 0; idx < length; idx += 4 { - b := data[2+idx : 2+idx+4] - ntpServers = append(ntpServers, net.IPv4(b[0], b[1], b[2], b[3])) - } - return &OptNTPServers{NTPServers: ntpServers}, nil + return &OptNTPServers{NTPServers: ips}, nil } // Code returns the option code. diff --git a/dhcpv4/option_ntp_servers_test.go b/dhcpv4/option_ntp_servers_test.go index e7bcefd..4d321ff 100644 --- a/dhcpv4/option_ntp_servers_test.go +++ b/dhcpv4/option_ntp_servers_test.go @@ -25,37 +25,23 @@ func TestParseOptNTPServers(t *testing.T) { 192, 168, 0, 10, // NTP server #1 192, 168, 0, 20, // NTP server #2 } - o, err := ParseOptNTPServers(data) + o, err := ParseOptNTPServers(data[2:]) require.NoError(t, err) ntpServers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), + net.IP{192, 168, 0, 10}, + net.IP{192, 168, 0, 20}, } require.Equal(t, &OptNTPServers{NTPServers: ntpServers}, o) - // Short byte stream - data = []byte{byte(OptionNTPServers)} - _, err = ParseOptNTPServers(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptNTPServers(data) - require.Error(t, err, "should get error from wrong code") - // Bad length - data = []byte{byte(OptionNTPServers), 6, 1, 1, 1} + data = []byte{1, 1, 1} _, err = ParseOptNTPServers(data) require.Error(t, err, "should get error from bad length") } func TestParseOptNTPserversNoNTPServers(t *testing.T) { // RFC2132 requires that at least one NTP server IP is specified - data := []byte{ - byte(OptionNTPServers), - 0, // Length - } - _, err := ParseOptNTPServers(data) + _, err := ParseOptNTPServers([]byte{}) require.Error(t, err) } diff --git a/dhcpv4/option_parameter_request_list.go b/dhcpv4/option_parameter_request_list.go index 865b2d7..8516e3b 100644 --- a/dhcpv4/option_parameter_request_list.go +++ b/dhcpv4/option_parameter_request_list.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "strings" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the parameter request list option @@ -16,23 +18,12 @@ type OptParameterRequestList struct { // ParseOptParameterRequestList returns a new OptParameterRequestList from a // byte stream, or error if any. func ParseOptParameterRequestList(data []byte) (*OptParameterRequestList, error) { - // Should at least have code + length byte. - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionParameterRequestList { - return nil, fmt.Errorf("expected code %v, got %v", OptionParameterRequestList, code) - } - length := int(data[1]) - if len(data) < length+2 { - return nil, ErrShortByteStream - } - var requestedOpts []OptionCode - for _, opt := range data[2 : length+2] { - requestedOpts = append(requestedOpts, OptionCode(opt)) + buf := uio.NewBigEndianBuffer(data) + requestedOpts := make([]OptionCode, 0, buf.Len()) + for buf.Len() > 0 { + requestedOpts = append(requestedOpts, OptionCode(buf.Read8())) } - return &OptParameterRequestList{RequestedOpts: requestedOpts}, nil + return &OptParameterRequestList{RequestedOpts: requestedOpts}, buf.Error() } // Code returns the option code. diff --git a/dhcpv4/option_parameter_request_list_test.go b/dhcpv4/option_parameter_request_list_test.go index f600a70..d54fc0f 100644 --- a/dhcpv4/option_parameter_request_list_test.go +++ b/dhcpv4/option_parameter_request_list_test.go @@ -23,16 +23,7 @@ func TestParseOptParameterRequestList(t *testing.T) { o *OptParameterRequestList err error ) - o, err = ParseOptParameterRequestList([]byte{}) - require.Error(t, err, "empty byte stream") - - o, err = ParseOptParameterRequestList([]byte{55, 2}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptParameterRequestList([]byte{53, 2, 1, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptParameterRequestList([]byte{55, 2, 67, 5}) + o, err = ParseOptParameterRequestList([]byte{67, 5}) require.NoError(t, err) expectedOpts := []OptionCode{OptionBootfileName, OptionNameServer} require.Equal(t, expectedOpts, o.RequestedOpts) diff --git a/dhcpv4/option_relay_agent_information.go b/dhcpv4/option_relay_agent_information.go index 447783e..d6547fe 100644 --- a/dhcpv4/option_relay_agent_information.go +++ b/dhcpv4/option_relay_agent_information.go @@ -8,42 +8,19 @@ import "fmt" // OptRelayAgentInformation is a "container" option for specific agent-supplied // sub-options. type OptRelayAgentInformation struct { - Options []Option + Options Options } // ParseOptRelayAgentInformation returns a new OptRelayAgentInformation from a // byte stream, or error if any. func ParseOptRelayAgentInformation(data []byte) (*OptRelayAgentInformation, error) { - if len(data) < 4 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRelayAgentInformation { - return nil, fmt.Errorf("expected code %v, got %v", OptionRelayAgentInformation, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - options, err := OptionsFromBytesWithParser(data[2:length+2], relayParseOption) + options, err := OptionsFromBytesWithParser(data, ParseOptionGeneric, false /* don't check for OptionEnd tag */) if err != nil { return nil, err } return &OptRelayAgentInformation{Options: options}, nil } -func relayParseOption(data []byte) (Option, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptionGeneric{OptionCode: code, Data: data[2:length+2]}, nil -} - // Code returns the option code. func (o *OptRelayAgentInformation) Code() OptionCode { return OptionRelayAgentInformation diff --git a/dhcpv4/option_relay_agent_information_test.go b/dhcpv4/option_relay_agent_information_test.go index 1e99206..bff8ced 100644 --- a/dhcpv4/option_relay_agent_information_test.go +++ b/dhcpv4/option_relay_agent_information_test.go @@ -14,33 +14,21 @@ func TestParseOptRelayAgentInformation(t *testing.T) { 2, 4, 'b', 'o', 'o', 't', } - // short option bytes - opt, err := ParseOptRelayAgentInformation([]byte{}) - require.Error(t, err) - - // wrong code - opt, err = ParseOptRelayAgentInformation([]byte{1, 2, 1, 0}) - require.Error(t, err) - - // wrong option length - opt, err = ParseOptRelayAgentInformation([]byte{82, 3, 1, 0}) - require.Error(t, err) - // short sub-option bytes - opt, err = ParseOptRelayAgentInformation([]byte{82, 3, 1, 0, 1}) + opt, err := ParseOptRelayAgentInformation([]byte{1, 0, 1}) require.Error(t, err) // short sub-option length - opt, err = ParseOptRelayAgentInformation([]byte{82, 2, 1, 1}) + opt, err = ParseOptRelayAgentInformation([]byte{1, 1}) require.Error(t, err) - opt, err = ParseOptRelayAgentInformation(data) + opt, err = ParseOptRelayAgentInformation(data[2:]) require.NoError(t, err) require.Equal(t, len(opt.Options), 2) - circuit, ok := opt.Options[0].(*OptionGeneric) - require.True(t, ok) - remote, ok := opt.Options[1].(*OptionGeneric) - require.True(t, ok) + circuit := opt.Options.GetOneOption(1).(*OptionGeneric) + require.NoError(t, err) + remote := opt.Options.GetOneOption(2).(*OptionGeneric) + require.NoError(t, err) require.Equal(t, circuit.Data, []byte("linux")) require.Equal(t, remote.Data, []byte("boot")) } diff --git a/dhcpv4/option_requested_ip_address.go b/dhcpv4/option_requested_ip_address.go index 8662263..3539278 100644 --- a/dhcpv4/option_requested_ip_address.go +++ b/dhcpv4/option_requested_ip_address.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the requested IP address option @@ -17,21 +19,8 @@ type OptRequestedIPAddress struct { // ParseOptRequestedIPAddress returns a new OptServerIdentifier from a byte // stream, or error if any. func ParseOptRequestedIPAddress(data []byte) (*OptRequestedIPAddress, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRequestedIPAddress { - return nil, fmt.Errorf("expected code %v, got %v", OptionRequestedIPAddress, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptRequestedIPAddress{RequestedAddr: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptRequestedIPAddress{RequestedAddr: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_requested_ip_address_test.go b/dhcpv4/option_requested_ip_address_test.go index efd6299..99592ea 100644 --- a/dhcpv4/option_requested_ip_address_test.go +++ b/dhcpv4/option_requested_ip_address_test.go @@ -29,16 +29,10 @@ func TestParseOptRequestedIPAddress(t *testing.T) { o, err = ParseOptRequestedIPAddress([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptRequestedIPAddress([]byte{byte(OptionRequestedIPAddress), 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptRequestedIPAddress([]byte{byte(OptionRequestedIPAddress), 3, 192, 168, 0, 1}) + o, err = ParseOptRequestedIPAddress([]byte{192}) require.Error(t, err, "wrong IP length") - o, err = ParseOptRequestedIPAddress([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptRequestedIPAddress([]byte{byte(OptionRequestedIPAddress), 4, 192, 168, 0, 1}) + o, err = ParseOptRequestedIPAddress([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.RequestedAddr) } diff --git a/dhcpv4/option_root_path.go b/dhcpv4/option_root_path.go index 504ed17..ba6f03f 100644 --- a/dhcpv4/option_root_path.go +++ b/dhcpv4/option_root_path.go @@ -15,19 +15,7 @@ type OptRootPath struct { // ParseOptRootPath constructs an OptRootPath struct from a sequence of bytes // and returns it, or an error. func ParseOptRootPath(data []byte) (*OptRootPath, error) { - // Should at least have code and length - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRootPath { - return nil, fmt.Errorf("expected option %v, got %v instead", OptionRootPath, code) - } - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - return &OptRootPath{Path: string(data[2 : 2+length])}, nil + return &OptRootPath{Path: string(data)}, nil } // Code returns the option code. diff --git a/dhcpv4/option_root_path_test.go b/dhcpv4/option_root_path_test.go index 53de45b..4bc7bc1 100644 --- a/dhcpv4/option_root_path_test.go +++ b/dhcpv4/option_root_path_test.go @@ -20,24 +20,9 @@ func TestOptRootPathInterfaceMethods(t *testing.T) { func TestParseOptRootPath(t *testing.T) { data := []byte{byte(OptionRootPath), 4, '/', 'f', 'o', 'o'} - o, err := ParseOptRootPath(data) + o, err := ParseOptRootPath(data[2:]) require.NoError(t, err) require.Equal(t, &OptRootPath{Path: "/foo"}, o) - - // Short byte stream - data = []byte{byte(OptionRootPath)} - _, err = ParseOptRootPath(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{43, 2, 1, 1} - _, err = ParseOptRootPath(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionRootPath), 6, 1, 1, 1} - _, err = ParseOptRootPath(data) - require.Error(t, err, "should get error from bad length") } func TestOptRootPathString(t *testing.T) { diff --git a/dhcpv4/option_router.go b/dhcpv4/option_router.go index 3154edd..60a96d1 100644 --- a/dhcpv4/option_router.go +++ b/dhcpv4/option_router.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the router option @@ -13,28 +15,30 @@ type OptRouter struct { Routers []net.IP } -// ParseOptRouter returns a new OptRouter from a byte stream, or error if any. -func ParseOptRouter(data []byte) (*OptRouter, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionRouter { - return nil, fmt.Errorf("expected code %v, got %v", OptionRouter, code) - } - length := int(data[1]) - if length == 0 || length%4 != 0 { - return nil, fmt.Errorf("Invalid length: expected multiple of 4 larger than 4, got %v", length) +// ParseIPs parses an IPv4 address from a DHCP packet as used and specified by +// options in RFC 2132, Sections 3.5 through 3.13, 8.2, 8.3, 8.5, 8.6, 8.9, and +// 8.10. +func ParseIPs(data []byte) ([]net.IP, error) { + buf := uio.NewBigEndianBuffer(data) + + if buf.Len() == 0 { + return nil, fmt.Errorf("IP DHCP options must always list at least one IP") } - if len(data) < 2+length { - return nil, ErrShortByteStream + + ips := make([]net.IP, 0, buf.Len()/net.IPv4len) + for buf.Has(net.IPv4len) { + ips = append(ips, net.IP(buf.CopyN(net.IPv4len))) } - routers := make([]net.IP, 0, length%4) - for idx := 0; idx < length; idx += 4 { - b := data[2+idx : 2+idx+4] - routers = append(routers, net.IPv4(b[0], b[1], b[2], b[3])) + return ips, buf.FinError() +} + +// ParseOptRouter returns a new OptRouter from a byte stream, or error if any. +func ParseOptRouter(data []byte) (*OptRouter, error) { + ips, err := ParseIPs(data) + if err != nil { + return nil, err } - return &OptRouter{Routers: routers}, nil + return &OptRouter{Routers: ips}, nil } // Code returns the option code. diff --git a/dhcpv4/option_router_test.go b/dhcpv4/option_router_test.go index f492c22..3264ce3 100644 --- a/dhcpv4/option_router_test.go +++ b/dhcpv4/option_router_test.go @@ -25,11 +25,11 @@ func TestParseOptRouter(t *testing.T) { 192, 168, 0, 10, // Router #1 192, 168, 0, 20, // Router #2 } - o, err := ParseOptRouter(data) + o, err := ParseOptRouter(data[2:]) require.NoError(t, err) routers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), + net.IP{192, 168, 0, 10}, + net.IP{192, 168, 0, 20}, } require.Equal(t, &OptRouter{Routers: routers}, o) @@ -37,16 +37,6 @@ func TestParseOptRouter(t *testing.T) { data = []byte{byte(OptionRouter)} _, err = ParseOptRouter(data) require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptRouter(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionRouter), 6, 1, 1, 1} - _, err = ParseOptRouter(data) - require.Error(t, err, "should get error from bad length") } func TestParseOptRouterNoRouters(t *testing.T) { @@ -60,6 +50,6 @@ func TestParseOptRouterNoRouters(t *testing.T) { } func TestOptRouterString(t *testing.T) { - o := OptRouter{Routers: []net.IP{net.IPv4(192, 168, 0, 1), net.IPv4(192, 168, 0, 10)}} + o := OptRouter{Routers: []net.IP{net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 10}}} require.Equal(t, "Routers -> 192.168.0.1, 192.168.0.10", o.String()) } diff --git a/dhcpv4/option_server_identifier.go b/dhcpv4/option_server_identifier.go index 26c21a7..fd0311a 100644 --- a/dhcpv4/option_server_identifier.go +++ b/dhcpv4/option_server_identifier.go @@ -3,12 +3,14 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) +// OptServerIdentifier represents an option encapsulating the server identifier. +// // This option implements the server identifier option // https://tools.ietf.org/html/rfc2132 - -// OptServerIdentifier represents an option encapsulating the server identifier. type OptServerIdentifier struct { ServerID net.IP } @@ -16,21 +18,8 @@ type OptServerIdentifier struct { // ParseOptServerIdentifier returns a new OptServerIdentifier from a byte // stream, or error if any. func ParseOptServerIdentifier(data []byte) (*OptServerIdentifier, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionServerIdentifier { - return nil, fmt.Errorf("expected code %v, got %v", OptionServerIdentifier, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptServerIdentifier{ServerID: net.IP(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptServerIdentifier{ServerID: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_server_identifier_test.go b/dhcpv4/option_server_identifier_test.go index 2951ebb..4a85469 100644 --- a/dhcpv4/option_server_identifier_test.go +++ b/dhcpv4/option_server_identifier_test.go @@ -29,16 +29,10 @@ func TestParseOptServerIdentifier(t *testing.T) { o, err = ParseOptServerIdentifier([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptServerIdentifier([]byte{54, 4, 192}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptServerIdentifier([]byte{54, 3, 192, 168, 0, 1}) + o, err = ParseOptServerIdentifier([]byte{192, 168, 0}) require.Error(t, err, "wrong IP length") - o, err = ParseOptServerIdentifier([]byte{53, 4, 192, 168, 1}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptServerIdentifier([]byte{54, 4, 192, 168, 0, 1}) + o, err = ParseOptServerIdentifier([]byte{192, 168, 0, 1}) require.NoError(t, err) require.Equal(t, net.IP{192, 168, 0, 1}, o.ServerID) } diff --git a/dhcpv4/option_subnet_mask.go b/dhcpv4/option_subnet_mask.go index f1ff4a4..86ce004 100644 --- a/dhcpv4/option_subnet_mask.go +++ b/dhcpv4/option_subnet_mask.go @@ -3,6 +3,8 @@ package dhcpv4 import ( "fmt" "net" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the subnet mask option @@ -16,21 +18,8 @@ type OptSubnetMask struct { // ParseOptSubnetMask returns a new OptSubnetMask from a byte // stream, or error if any. func ParseOptSubnetMask(data []byte) (*OptSubnetMask, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionSubnetMask { - return nil, fmt.Errorf("expected code %v, got %v", OptionSubnetMask, code) - } - length := int(data[1]) - if length != 4 { - return nil, fmt.Errorf("unexepcted length: expected 4, got %v", length) - } - if len(data) < 6 { - return nil, ErrShortByteStream - } - return &OptSubnetMask{SubnetMask: net.IPMask(data[2 : 2+length])}, nil + buf := uio.NewBigEndianBuffer(data) + return &OptSubnetMask{SubnetMask: net.IPMask(buf.CopyN(net.IPv4len))}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_subnet_mask_test.go b/dhcpv4/option_subnet_mask_test.go index 4cb8819..e3c37bf 100644 --- a/dhcpv4/option_subnet_mask_test.go +++ b/dhcpv4/option_subnet_mask_test.go @@ -29,16 +29,10 @@ func TestParseOptSubnetMask(t *testing.T) { o, err = ParseOptSubnetMask([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptSubnetMask([]byte{1, 4, 255}) + o, err = ParseOptSubnetMask([]byte{255}) require.Error(t, err, "short byte stream") - o, err = ParseOptSubnetMask([]byte{1, 3, 255, 255, 255, 0}) - require.Error(t, err, "wrong IP length") - - o, err = ParseOptSubnetMask([]byte{2, 4, 255, 255, 255}) - require.Error(t, err, "wrong option code") - - o, err = ParseOptSubnetMask([]byte{1, 4, 255, 255, 255, 0}) + o, err = ParseOptSubnetMask([]byte{255, 255, 255, 0}) require.NoError(t, err) require.Equal(t, net.IPMask{255, 255, 255, 0}, o.SubnetMask) } diff --git a/dhcpv4/option_tftp_server_name.go b/dhcpv4/option_tftp_server_name.go index 3a310f9..2a4af6d 100644 --- a/dhcpv4/option_tftp_server_name.go +++ b/dhcpv4/option_tftp_server_name.go @@ -9,7 +9,7 @@ import ( // OptTFTPServerName implements the TFTP server name option. type OptTFTPServerName struct { - TFTPServerName []byte + TFTPServerName string } // Code returns the option code @@ -19,7 +19,7 @@ func (op *OptTFTPServerName) Code() OptionCode { // ToBytes serializes the option and returns it as a sequence of bytes func (op *OptTFTPServerName) ToBytes() []byte { - return append([]byte{byte(op.Code()), byte(op.Length())}, op.TFTPServerName...) + return append([]byte{byte(op.Code()), byte(op.Length())}, []byte(op.TFTPServerName)...) } // Length returns the option length in bytes @@ -33,22 +33,5 @@ func (op *OptTFTPServerName) String() string { // ParseOptTFTPServerName returns a new OptTFTPServerName from a byte stream or error if any func ParseOptTFTPServerName(data []byte) (*OptTFTPServerName, error) { - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionTFTPServerName { - return nil, fmt.Errorf("ParseOptTFTPServerName: invalid code: %v; want %v", - code, OptionTFTPServerName) - } - length := int(data[1]) - if length < 1 { - return nil, fmt.Errorf("TFTP server name has invalid length of %d", length) - } - TFTPServerNameData := data[2:] - if len(TFTPServerNameData) < length { - return nil, fmt.Errorf("ParseOptTFTPServerName: short data: %d bytes; want %d", - len(TFTPServerNameData), length) - } - return &OptTFTPServerName{TFTPServerName: TFTPServerNameData[:length]}, nil + return &OptTFTPServerName{TFTPServerName: string(data)}, nil } diff --git a/dhcpv4/option_tftp_server_name_test.go b/dhcpv4/option_tftp_server_name_test.go index 812210f..54efef8 100644 --- a/dhcpv4/option_tftp_server_name_test.go +++ b/dhcpv4/option_tftp_server_name_test.go @@ -13,7 +13,7 @@ func TestOptTFTPServerNameCode(t *testing.T) { func TestOptTFTPServerNameToBytes(t *testing.T) { opt := OptTFTPServerName{ - TFTPServerName: []byte("linuxboot"), + TFTPServerName: "linuxboot", } data := opt.ToBytes() expected := []byte{ @@ -26,7 +26,7 @@ func TestOptTFTPServerNameToBytes(t *testing.T) { func TestParseOptTFTPServerName(t *testing.T) { expected := []byte{ - 66, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptTFTPServerName(expected) require.NoError(t, err) @@ -34,32 +34,7 @@ func TestParseOptTFTPServerName(t *testing.T) { require.Equal(t, "linuxboot", string(opt.TFTPServerName)) } -func TestParseOptTFTPServerNameZeroLength(t *testing.T) { - expected := []byte{ - 66, 0, - } - _, err := ParseOptTFTPServerName(expected) - require.Error(t, err) -} - -func TestParseOptTFTPServerNameInvalidLength(t *testing.T) { - expected := []byte{ - 66, 9, 'l', 'i', 'n', 'u', 'x', 'b', - } - _, err := ParseOptTFTPServerName(expected) - require.Error(t, err) -} - -func TestParseOptTFTPServerNameShortLength(t *testing.T) { - expected := []byte{ - 66, 4, 'l', 'i', 'n', 'u', 'x', - } - opt, err := ParseOptTFTPServerName(expected) - require.NoError(t, err) - require.Equal(t, []byte("linu"), opt.TFTPServerName) -} - func TestOptTFTPServerNameString(t *testing.T) { - o := OptTFTPServerName{TFTPServerName: []byte("testy test")} + o := OptTFTPServerName{TFTPServerName: "testy test"} require.Equal(t, "TFTP Server Name -> testy test", o.String()) } diff --git a/dhcpv4/option_userclass.go b/dhcpv4/option_userclass.go index d6ddabc..44ce090 100644 --- a/dhcpv4/option_userclass.go +++ b/dhcpv4/option_userclass.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "strings" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the User Class option @@ -12,7 +14,7 @@ import ( // OptUserClass represents an option encapsulating User Classes. type OptUserClass struct { UserClasses [][]byte - Rfc3004 bool + Rfc3004 bool } // Code returns the option code @@ -61,21 +63,7 @@ func (op *OptUserClass) String() string { // error if any func ParseOptUserClass(data []byte) (*OptUserClass, error) { opt := OptUserClass{} - - if len(data) < 3 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionUserClassInformation { - return nil, fmt.Errorf("expected code %v, got %v", OptionUserClassInformation, code) - } - - totalLength := int(data[1]) - data = data[2:] - if len(data) < totalLength { - return nil, fmt.Errorf("ParseOptUserClass: short data: length is %d but got %d bytes", - totalLength, len(data)) - } + buf := uio.NewBigEndianBuffer(data) // Check if option is Microsoft style instead of RFC compliant, issue #113 @@ -85,29 +73,24 @@ func ParseOptUserClass(data []byte) (*OptUserClass, error) { // option length. If the lengths don't add up, we assume that the option // is a single string and non RFC3004 compliant var counting int - for counting < totalLength { + for counting < buf.Len() { // UC_Len_i does not include itself so add 1 counting += int(data[counting]) + 1 } - if counting != totalLength { - opt.UserClasses = append(opt.UserClasses, data[:totalLength]) + if counting != buf.Len() { + opt.UserClasses = append(opt.UserClasses, data) return &opt, nil } opt.Rfc3004 = true - for i := 0; i < totalLength; { - ucLen := int(data[i]) + for buf.Has(1) { + ucLen := buf.Read8() if ucLen == 0 { - return nil, errors.New("User Class value has invalid length of 0") - } - base := i + 1 - if len(data) < base+ucLen { - return nil, fmt.Errorf("ParseOptUserClass: short data: %d bytes; want: %d", len(data), base+ucLen) + return nil, fmt.Errorf("DHCP user class must have length greater than 0") } - opt.UserClasses = append(opt.UserClasses, data[base:base+ucLen]) - i += base + ucLen + opt.UserClasses = append(opt.UserClasses, buf.CopyN(int(ucLen))) } - if len(opt.UserClasses) < 1 { + if len(opt.UserClasses) == 0 { return nil, errors.New("ParseOptUserClass: at least one user class is required") } - return &opt, nil + return &opt, buf.FinError() } diff --git a/dhcpv4/option_userclass_test.go b/dhcpv4/option_userclass_test.go index f6039df..d392ed8 100644 --- a/dhcpv4/option_userclass_test.go +++ b/dhcpv4/option_userclass_test.go @@ -9,7 +9,7 @@ import ( func TestOptUserClassToBytes(t *testing.T) { opt := OptUserClass{ UserClasses: [][]byte{[]byte("linuxboot")}, - Rfc3004: true, + Rfc3004: true, } data := opt.ToBytes() expected := []byte{ @@ -35,7 +35,6 @@ func TestOptUserClassMicrosoftToBytes(t *testing.T) { func TestParseOptUserClassMultiple(t *testing.T) { expected := []byte{ - 77, 15, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 4, 't', 'e', 's', 't', } @@ -54,7 +53,7 @@ func TestParseOptUserClassNone(t *testing.T) { func TestParseOptUserClassMicrosoft(t *testing.T) { expected := []byte{ - 77, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) @@ -64,7 +63,7 @@ func TestParseOptUserClassMicrosoft(t *testing.T) { func TestParseOptUserClassMicrosoftShort(t *testing.T) { expected := []byte{ - 77, 1, 'l', + 'l', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) @@ -72,19 +71,9 @@ func TestParseOptUserClassMicrosoftShort(t *testing.T) { require.Equal(t, []byte("l"), opt.UserClasses[0]) } -func TestParseOptUserClassMicrosoftLongerThanLength(t *testing.T) { - expected := []byte{ - 77, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', - } - opt, err := ParseOptUserClass(expected) - require.NoError(t, err) - require.Equal(t, 1, len(opt.UserClasses)) - require.Equal(t, []byte("linuxboot"), opt.UserClasses[0]) -} - func TestParseOptUserClass(t *testing.T) { expected := []byte{ - 77, 10, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', + 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) @@ -110,44 +99,18 @@ func TestOptUserClassToBytesMultiple(t *testing.T) { require.Equal(t, expected, data) } -func TestParseOptUserClassLongerThanLength(t *testing.T) { - expected := []byte{ - 77, 10, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', - } - opt, err := ParseOptUserClass(expected) - require.NoError(t, err) - require.Equal(t, 1, len(opt.UserClasses)) - require.Equal(t, []byte("linuxboot"), opt.UserClasses[0]) -} - -func TestParseOptUserClassShorterTotalLength(t *testing.T) { - expected := []byte{ - 77, 11, 10, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', - } - _, err := ParseOptUserClass(expected) - require.Error(t, err) -} - func TestOptUserClassLength(t *testing.T) { expected := []byte{ - 77, 10, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', + 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 'X', } opt, err := ParseOptUserClass(expected) require.NoError(t, err) - require.Equal(t, 10, opt.Length()) + require.Equal(t, 11, opt.Length()) } func TestParseOptUserClassZeroLength(t *testing.T) { expected := []byte{ - 77, 1, 0, 0, - } - _, err := ParseOptUserClass(expected) - require.Error(t, err) -} - -func TestParseOptUserClassMultipleWithZeroLength(t *testing.T) { - expected := []byte{ - 77, 12, 10, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 0, + 0, 0, } _, err := ParseOptUserClass(expected) require.Error(t, err) diff --git a/dhcpv4/option_vivc.go b/dhcpv4/option_vivc.go index 7576637..4ff42a3 100644 --- a/dhcpv4/option_vivc.go +++ b/dhcpv4/option_vivc.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // This option implements the Vendor-Identifying Vendor Class Option @@ -23,39 +25,16 @@ type OptVIVC struct { // ParseOptVIVC contructs an OptVIVC tsruct from a sequence of bytes and returns // it, or an error. func ParseOptVIVC(data []byte) (*OptVIVC, error) { - if len(data) < 2 { - return nil, ErrShortByteStream - } - code := OptionCode(data[0]) - if code != OptionVendorIdentifyingVendorClass { - return nil, fmt.Errorf("expected code %v, got %v", OptionVendorIdentifyingVendorClass, code) - } - - length := int(data[1]) - if len(data) < 2+length { - return nil, ErrShortByteStream - } - data = data[2:length+2] - - ids := []VIVCIdentifier{} - for len(data) > 5 { - entID := binary.BigEndian.Uint32(data[0:4]) - idLen := int(data[4]) - data = data[5:] - - if idLen > len(data) { - return nil, ErrShortByteStream - } - - ids = append(ids, VIVCIdentifier{EntID: entID, Data: data[:idLen]}) - data = data[idLen:] - } + buf := uio.NewBigEndianBuffer(data) - if len(data) != 0 { - return nil, ErrShortByteStream + var ids []VIVCIdentifier + for buf.Has(5) { + entID := buf.Read32() + idLen := int(buf.Read8()) + ids = append(ids, VIVCIdentifier{EntID: entID, Data: buf.CopyN(idLen)}) } - return &OptVIVC{Identifiers: ids}, nil + return &OptVIVC{Identifiers: ids}, buf.FinError() } // Code returns the option code. diff --git a/dhcpv4/option_vivc_test.go b/dhcpv4/option_vivc_test.go index b290da0..0b4ab7c 100644 --- a/dhcpv4/option_vivc_test.go +++ b/dhcpv4/option_vivc_test.go @@ -31,36 +31,20 @@ func TestOptVIVCInterfaceMethods(t *testing.T) { } func TestParseOptVICO(t *testing.T) { - o, err := ParseOptVIVC(sampleVIVCOptRaw) + o, err := ParseOptVIVC(sampleVIVCOptRaw[2:]) require.NoError(t, err) require.Equal(t, &sampleVIVCOpt, o) - // Short byte stream - data := []byte{byte(OptionVendorIdentifyingVendorClass)} - _, err = ParseOptVIVC(data) - require.Error(t, err, "should get error from short byte stream") - - // Wrong code - data = []byte{54, 2, 1, 1} - _, err = ParseOptVIVC(data) - require.Error(t, err, "should get error from wrong code") - - // Bad length - data = []byte{byte(OptionVendorIdentifyingVendorClass), 6, 1, 1, 1} - _, err = ParseOptVIVC(data) - require.Error(t, err, "should get error from bad length") - // Identifier len too long - data = make([]byte, len(sampleVIVCOptRaw)) - copy(data, sampleVIVCOptRaw) - data[6] = 40 + data := make([]byte, len(sampleVIVCOptRaw[2:])) + copy(data, sampleVIVCOptRaw[2:]) + data[4] = 40 _, err = ParseOptVIVC(data) require.Error(t, err, "should get error from bad length") // Longer than length - data[1] = 10 - data[6] = 5 - o, err = ParseOptVIVC(data) + data[4] = 5 + o, err = ParseOptVIVC(data[:10]) require.NoError(t, err) require.Equal(t, o.Identifiers[0].Data, []byte("Cisco")) } diff --git a/dhcpv4/options.go b/dhcpv4/options.go index a5a51be..215da39 100644 --- a/dhcpv4/options.go +++ b/dhcpv4/options.go @@ -2,15 +2,26 @@ package dhcpv4 import ( "errors" + "fmt" + "io" + + "github.com/u-root/u-root/pkg/uio" ) -// ErrShortByteStream is an error that is thrown any time a short byte stream is -// detected during option parsing. -var ErrShortByteStream = errors.New("short byte stream") +var ( + // ErrShortByteStream is an error that is thrown any time a short byte stream is + // detected during option parsing. + ErrShortByteStream = errors.New("short byte stream") + + // ErrZeroLengthByteStream is an error that is thrown any time a zero-length + // byte stream is encountered. + ErrZeroLengthByteStream = errors.New("zero-length byte stream") -// ErrZeroLengthByteStream is an error that is thrown any time a zero-length -// byte stream is encountered. -var ErrZeroLengthByteStream = errors.New("zero-length byte stream") + // ErrInvalidOptions is returned when invalid options data is + // encountered during parsing. The data could report an incorrect + // length or have trailing bytes which are not part of the option. + ErrInvalidOptions = errors.New("invalid options data") +) // OptionCode is a single byte representing the code for a given Option. type OptionCode byte @@ -25,15 +36,12 @@ type Option interface { // ParseOption parses a sequence of bytes as a single DHCPv4 option, returning // the specific option structure or error, if any. -func ParseOption(data []byte) (Option, error) { - if len(data) == 0 { - return nil, errors.New("invalid zero-length DHCPv4 option") - } +func ParseOption(code OptionCode, data []byte) (Option, error) { var ( opt Option err error ) - switch OptionCode(data[0]) { + switch code { case OptionSubnetMask: opt, err = ParseOptSubnetMask(data) case OptionRouter: @@ -79,7 +87,7 @@ func ParseOption(data []byte) (Option, error) { case OptionVendorIdentifyingVendorClass: opt, err = ParseOptVIVC(data) default: - opt, err = ParseOptionGeneric(data) + opt, err = ParseOptionGeneric(code, data) } if err != nil { return nil, err @@ -94,41 +102,74 @@ func ParseOption(data []byte) (Option, error) { // // Returns an error if any invalid option or length is found. func OptionsFromBytes(data []byte) ([]Option, error) { - return OptionsFromBytesWithParser(data, ParseOption) + return OptionsFromBytesWithParser(data, ParseOption, true) } // OptionParser is a function signature for option parsing -type OptionParser func(data []byte) (Option, error) +type OptionParser func(code OptionCode, data []byte) (Option, error) // OptionsFromBytesWithParser parses Options from byte sequences using the // parsing function that is passed in as a paremeter -func OptionsFromBytesWithParser(data []byte, parser OptionParser) ([]Option, error) { - options := make([]Option, 0, 10) - idx := 0 - for { - if idx == len(data) { +func OptionsFromBytesWithParser(data []byte, parser OptionParser, checkEndOption bool) (Options, error) { + if len(data) == 0 { + return nil, nil + } + buf := uio.NewBigEndianBuffer(data) + options := make(map[OptionCode][]byte, 10) + var order []OptionCode + + // Due to RFC 3396 allowing an option to be specified multiple times, + // we have to collect all option data first, and then parse it. + var end bool + for buf.Len() >= 1 { + // 1 byte: option code + // 1 byte: option length n + // n bytes: data + code := OptionCode(buf.Read8()) + + if code == OptionPad { + continue + } else if code == OptionEnd { + end = true break } - // This should never happen. - if idx > len(data) { - return nil, errors.New("read past the end of options") + length := int(buf.Read8()) + + // N bytes: option data + data := buf.Consume(length) + if data == nil { + return nil, fmt.Errorf("error collecting options: %v", buf.Error()) } - opt, err := parser(data[idx:]) - idx++ - if err != nil { - return nil, err + data = data[:length:length] + + if _, ok := options[code]; !ok { + order = append(order, code) } - options = append(options, opt) - if opt.Code() == OptionEnd { - break + // RFC 3396: Just concatenate the data if the option code was + // specified multiple times. + options[code] = append(options[code], data...) + } + + // If we never read the End option, the sender of this packet screwed + // up. + if !end && checkEndOption { + return nil, io.ErrUnexpectedEOF + } + + // Any bytes left must be padding. + for buf.Len() >= 1 { + if OptionCode(buf.Read8()) != OptionPad { + return nil, ErrInvalidOptions } + } - // Options with zero length have no length byte, so here we handle the - // ones with nonzero length - if opt.Code() != OptionPad { - idx++ + opts := make(Options, 0, 10) + for _, code := range order { + parsedOpt, err := parser(code, options[code]) + if err != nil { + return nil, fmt.Errorf("error parsing option code %s: %v", code, err) } - idx += opt.Length() + opts = append(opts, parsedOpt) } - return options, nil + return opts, nil } diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go index 2557af5..9a2f1c0 100644 --- a/dhcpv4/options_test.go +++ b/dhcpv4/options_test.go @@ -1,6 +1,9 @@ package dhcpv4 import ( + "bytes" + "fmt" + "math" "testing" "github.com/stretchr/testify/require" @@ -9,7 +12,7 @@ import ( func TestParseOption(t *testing.T) { // Generic option := []byte{5, 4, 192, 168, 1, 254} // DNS option - opt, err := ParseOption(option) + opt, err := ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) generic := opt.(*OptionGeneric) require.Equal(t, OptionNameServer, generic.Code()) @@ -19,7 +22,7 @@ func TestParseOption(t *testing.T) { // Option subnet mask option = []byte{1, 4, 255, 255, 255, 0} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionSubnetMask, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -27,7 +30,7 @@ func TestParseOption(t *testing.T) { // Option router option = []byte{3, 4, 192, 168, 1, 1} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRouter, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -35,7 +38,7 @@ func TestParseOption(t *testing.T) { // Option domain name server option = []byte{6, 4, 192, 168, 1, 1} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionDomainNameServer, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -43,7 +46,7 @@ func TestParseOption(t *testing.T) { // Option host name option = []byte{12, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionHostName, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -51,7 +54,7 @@ func TestParseOption(t *testing.T) { // Option domain name option = []byte{15, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionDomainName, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -59,7 +62,7 @@ func TestParseOption(t *testing.T) { // Option root path option = []byte{17, 4, '/', 'f', 'o', 'o'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRootPath, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -67,7 +70,7 @@ func TestParseOption(t *testing.T) { // Option broadcast address option = []byte{28, 4, 255, 255, 255, 255} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionBroadcastAddress, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -75,7 +78,7 @@ func TestParseOption(t *testing.T) { // Option NTP servers option = []byte{42, 4, 10, 10, 10, 10} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionNTPServers, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -83,7 +86,7 @@ func TestParseOption(t *testing.T) { // Requested IP address option = []byte{50, 4, 1, 2, 3, 4} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRequestedIPAddress, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -91,7 +94,7 @@ func TestParseOption(t *testing.T) { // Requested IP address lease time option = []byte{51, 4, 0, 0, 0, 0} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionIPAddressLeaseTime, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -99,7 +102,7 @@ func TestParseOption(t *testing.T) { // Message type option = []byte{53, 1, 1} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionDHCPMessageType, opt.Code(), "Code") require.Equal(t, 1, opt.Length(), "Length") @@ -107,7 +110,7 @@ func TestParseOption(t *testing.T) { // Option server ID option = []byte{54, 4, 1, 2, 3, 4} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionServerIdentifier, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -115,7 +118,7 @@ func TestParseOption(t *testing.T) { // Parameter request list option = []byte{55, 3, 5, 53, 61} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionParameterRequestList, opt.Code(), "Code") require.Equal(t, 3, opt.Length(), "Length") @@ -123,7 +126,7 @@ func TestParseOption(t *testing.T) { // Option max message size option = []byte{57, 2, 1, 2} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionMaximumDHCPMessageSize, opt.Code(), "Code") require.Equal(t, 2, opt.Length(), "Length") @@ -131,7 +134,7 @@ func TestParseOption(t *testing.T) { // Option class identifier option = []byte{60, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionClassIdentifier, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -139,7 +142,7 @@ func TestParseOption(t *testing.T) { // Option TFTP server name option = []byte{66, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionTFTPServerName, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") @@ -147,7 +150,7 @@ func TestParseOption(t *testing.T) { // Option Bootfile name option = []byte{67, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionBootfileName, opt.Code(), "Code") require.Equal(t, 9, opt.Length(), "Length") @@ -155,37 +158,131 @@ func TestParseOption(t *testing.T) { // Option user class information option = []byte{77, 5, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionUserClassInformation, opt.Code(), "Code") require.Equal(t, 5, opt.Length(), "Length") require.Equal(t, option, opt.ToBytes(), "ToBytes") // Option relay agent information - option = []byte{82, 2, 1, 0} - opt, err = ParseOption(option) + option = []byte{82, 6, 1, 4, 129, 168, 0, 1} + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionRelayAgentInformation, opt.Code(), "Code") - require.Equal(t, 2, opt.Length(), "Length") + require.Equal(t, 6, opt.Length(), "Length") require.Equal(t, option, opt.ToBytes(), "ToBytes") // Option client system architecture type option option = []byte{93, 4, 't', 'e', 's', 't'} - opt, err = ParseOption(option) + opt, err = ParseOption(OptionCode(option[0]), option[2:]) require.NoError(t, err) require.Equal(t, OptionClientSystemArchitectureType, opt.Code(), "Code") require.Equal(t, 4, opt.Length(), "Length") require.Equal(t, option, opt.ToBytes(), "ToBytes") } -func TestParseOptionZeroLength(t *testing.T) { - option := []byte{} - _, err := ParseOption(option) - require.Error(t, err, "should get error from zero-length options") -} - -func TestParseOptionShortOption(t *testing.T) { - option := []byte{53, 1} - _, err := ParseOption(option) - require.Error(t, err, "should get error from short options") +func TestOptionsUnmarshal(t *testing.T) { + for i, tt := range []struct { + input []byte + want Options + wantError bool + }{ + { + // Buffer missing data. + input: []byte{ + 3 /* key */, 3 /* length */, 1, + }, + wantError: true, + }, + { + input: []byte{ + // This may look too long, but 0 is padding. + // The issue here is the missing OptionEnd. + 3, 3, 0, 0, 0, 0, 0, 0, 0, + }, + wantError: true, + }, + { + // Only OptionPad and OptionEnd can stand on their own + // without a length field. So this is too short. + input: []byte{ + 3, + }, + wantError: true, + }, + { + // Option present after the End is a nono. + input: []byte{byte(OptionEnd), 3}, + wantError: true, + }, + { + input: []byte{byte(OptionEnd)}, + want: Options{}, + }, + { + input: []byte{ + 3, 2, 5, 6, + byte(OptionEnd), + }, + want: Options{ + &OptionGeneric{ + OptionCode: 3, + Data: []byte{5, 6}, + }, + }, + }, + { + // Test RFC 3396. + input: append( + append([]byte{3, math.MaxUint8}, bytes.Repeat([]byte{10}, math.MaxUint8)...), + 3, 5, 10, 10, 10, 10, 10, + byte(OptionEnd), + ), + want: Options{ + &OptionGeneric{ + OptionCode: 3, + Data: bytes.Repeat([]byte{10}, math.MaxUint8+5), + }, + }, + }, + { + input: []byte{ + 10, 2, 255, 254, + 11, 3, 5, 5, 5, + byte(OptionEnd), + }, + want: Options{ + &OptionGeneric{ + OptionCode: 10, + Data: []byte{255, 254}, + }, + &OptionGeneric{ + OptionCode: 11, + Data: []byte{5, 5, 5}, + }, + }, + }, + { + input: append( + append([]byte{10, 2, 255, 254}, bytes.Repeat([]byte{byte(OptionPad)}, 255)...), + byte(OptionEnd), + ), + want: Options{ + &OptionGeneric{ + OptionCode: 10, + Data: []byte{255, 254}, + }, + }, + }, + } { + t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) { + opt, err := OptionsFromBytesWithParser(tt.input, ParseOptionGeneric, true) + if tt.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, opt, tt.want) + } + }) + } } |