diff options
author | Christopher Koch <chrisko@google.com> | 2018-12-29 14:48:10 -0800 |
---|---|---|
committer | insomniac <insomniacslk@users.noreply.github.com> | 2019-01-24 08:05:49 +0000 |
commit | c90ab10024ada840e24bb028a3405961e8e4c26a (patch) | |
tree | 9b8af0c1b80ee6efc112921f9a14b92d6c73f8eb | |
parent | 2be5cae32d33f01ddecf6f167a9c0e5290e6d58f (diff) |
dhcpv4: nicer API for option parsing.
From:
r := d.GetOneOption(OptionRouter).(*OptRouter).Routers
d.UpdateOption(&OptRouter{Routers: []net.IP{net.IP{192, 168, 0, 1}}})
To:
r := GetRouter(d.Options)
d.UpdateOption(OptRouter(net.IP{192, 168, 0, 1}, ...))
71 files changed, 2114 insertions, 2640 deletions
diff --git a/dhcpv4/bsdp/boot_image.go b/dhcpv4/bsdp/boot_image.go index 954dcb6..58b5167 100644 --- a/dhcpv4/bsdp/boot_image.go +++ b/dhcpv4/bsdp/boot_image.go @@ -3,6 +3,7 @@ package bsdp import ( "fmt" + "github.com/insomniacslk/dhcp/dhcpv4" "github.com/u-root/u-root/pkg/uio" ) @@ -18,9 +19,9 @@ const ( // 4 - 127 are reserved for future use. ) -// BootImageTypeToString maps the different BootImageTypes to human-readable +// bootImageTypeToString maps the different BootImageTypes to human-readable // representations. -var BootImageTypeToString = map[BootImageType]string{ +var bootImageTypeToString = map[BootImageType]string{ BootImageTypeMacOS9: "macOS 9", BootImageTypeMacOSX: "macOS", BootImageTypeMacOSXServer: "macOS Server", @@ -35,6 +36,16 @@ type BootImageID struct { Index uint16 } +// ToBytes implements dhcpv4.OptionValue. +func (b BootImageID) ToBytes() []byte { + return uio.ToBigEndian(b) +} + +// FromBytes reads data into b. +func (b *BootImageID) FromBytes(data []byte) error { + return uio.FromBigEndian(b, data) +} + // Marshal writes the binary representation to buf. func (b BootImageID) Marshal(buf *uio.Lexer) { var byte0 byte @@ -55,7 +66,7 @@ func (b BootImageID) String() string { } else { s += " uninstallable" } - t, ok := BootImageTypeToString[b.ImageType] + t, ok := bootImageTypeToString[b.ImageType] if !ok { t = "unknown" } @@ -99,3 +110,37 @@ func (b *BootImage) Unmarshal(buf *uio.Lexer) error { b.Name = string(buf.Consume(int(nameLength))) return buf.Error() } + +func getBootImageID(code dhcpv4.OptionCode, o dhcpv4.Options) *BootImageID { + v := o.Get(code) + if v == nil { + return nil + } + var b BootImageID + if err := uio.FromBigEndian(&b, v); err != nil { + return nil + } + return &b +} + +// OptDefaultBootImageID returns a new default boot image ID option as per +// BSDP. +func OptDefaultBootImageID(b BootImageID) dhcpv4.Option { + return dhcpv4.Option{Code: OptionDefaultBootImageID, Value: b} +} + +// GetDefaultBootImageID returns the default boot image ID contained in o. +func GetDefaultBootImageID(o dhcpv4.Options) *BootImageID { + return getBootImageID(OptionDefaultBootImageID, o) +} + +// OptSelectedBootImageID returns a new selected boot image ID option as per +// BSDP. +func OptSelectedBootImageID(b BootImageID) dhcpv4.Option { + return dhcpv4.Option{Code: OptionSelectedBootImageID, Value: b} +} + +// GetSelectedBootImageID returns the selected boot image ID contained in o. +func GetSelectedBootImageID(o dhcpv4.Options) *BootImageID { + return getBootImageID(OptionSelectedBootImageID, o) +} diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 3cc87d2..9bcc15d 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -30,20 +30,12 @@ type ReplyConfig struct { // ParseBootImageListFromAck parses the list of boot images presented in the // ACK[LIST] packet and returns them as a list of BootImages. -func ParseBootImageListFromAck(ack dhcpv4.DHCPv4) ([]BootImage, error) { - opt := ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation) - if opt == nil { +func ParseBootImageListFromAck(ack *dhcpv4.DHCPv4) ([]BootImage, error) { + vendorOpts := GetVendorOptions(ack.Options) + if vendorOpts == nil { return nil, errors.New("ParseBootImageListFromAck: could not find vendor-specific option") } - vendorOpt, err := ParseOptVendorSpecificInformation(opt.ToBytes()) - if err != nil { - return nil, err - } - bootImageOpts := vendorOpt.GetOneOption(OptionBootImageList) - if bootImageOpts == nil { - return nil, fmt.Errorf("boot image option not found") - } - return bootImageOpts.(*OptBootImageList).Images, nil + return GetBootImageList(vendorOpts.Options), nil } func needsReplyPort(replyPort uint16) bool { @@ -53,28 +45,41 @@ func needsReplyPort(replyPort uint16) bool { // MessageTypeFromPacket extracts the BSDP message type (LIST, SELECT) from the // vendor-specific options and returns it. If the message type option cannot be // found, returns false. -func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) *MessageType { - var ( - vendorOpts *OptVendorSpecificInformation - err error - ) - opt := packet.GetOneOption(dhcpv4.OptionVendorSpecificInformation) - if opt == nil { - return nil - } - if vendorOpts, err = ParseOptVendorSpecificInformation(opt.ToBytes()); err == nil { - if o := vendorOpts.GetOneOption(OptionMessageType); o != nil { - if optMessageType, ok := o.(*OptMessageType); ok { - return &optMessageType.Type - } - } - } - return nil +func MessageTypeFromPacket(packet *dhcpv4.DHCPv4) MessageType { + vendorOpts := GetVendorOptions(packet.Options) + if vendorOpts == nil { + return MessageTypeNone + } + return GetMessageType(vendorOpts.Options) +} + +// Packet is a BSDP packet wrapper around a DHCPv4 packet in order to print the +// correct vendor-specific BSDP information in String(). +type Packet struct { + dhcpv4.DHCPv4 +} + +// PacketFor returns a wrapped BSDP Packet given a DHCPv4 packet. +func PacketFor(d *dhcpv4.DHCPv4) *Packet { + return &Packet{*d} +} + +func (p Packet) v4() *dhcpv4.DHCPv4 { + return &p.DHCPv4 +} + +func (p Packet) String() string { + return p.DHCPv4.String() +} + +// Summary prints the BSDP packet with the correct vendor-specific options. +func (p Packet) Summary() string { + return p.DHCPv4.SummaryWithVendor(&VendorOptions{}) } // NewInformListForInterface creates a new INFORM packet for interface ifname // with configuration options specified by config. -func NewInformListForInterface(ifname string, replyPort uint16) (*dhcpv4.DHCPv4, error) { +func NewInformListForInterface(ifname string, replyPort uint16) (*Packet, error) { iface, err := net.InterfaceByName(ifname) if err != nil { return nil, err @@ -96,7 +101,7 @@ func NewInformListForInterface(ifname string, replyPort uint16) (*dhcpv4.DHCPv4, // NewInformList creates a new INFORM packet for interface with hardware address // `hwaddr` and IP `localIP`. Packet will be sent out on port `replyPort`. -func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16, modifiers ...dhcpv4.Modifier) (*dhcpv4.DHCPv4, error) { +func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16, modifiers ...dhcpv4.Modifier) (*Packet, error) { // Validate replyPort first if needsReplyPort(replyPort) && replyPort >= 1024 { return nil, errors.New("replyPort must be a privileged port") @@ -109,60 +114,61 @@ func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16, mo // These are vendor-specific options used to pass along BSDP information. vendorOpts := []dhcpv4.Option{ - &OptMessageType{MessageTypeList}, - Version1_1, + OptMessageType(MessageTypeList), + OptVersion(Version1_1), } if needsReplyPort(replyPort) { - vendorOpts = append(vendorOpts, &OptReplyPort{replyPort}) + vendorOpts = append(vendorOpts, OptReplyPort(replyPort)) } - return dhcpv4.NewInform(hwaddr, localIP, + d, err := dhcpv4.NewInform(hwaddr, localIP, dhcpv4.PrependModifiers(modifiers, dhcpv4.WithRequestedOptions( dhcpv4.OptionVendorSpecificInformation, dhcpv4.OptionClassIdentifier, ), - dhcpv4.WithOption(&dhcpv4.OptMaximumDHCPMessageSize{Size: MaxDHCPMessageSize}), - dhcpv4.WithOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}), - dhcpv4.WithOption(&OptVendorSpecificInformation{vendorOpts}), + dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxDHCPMessageSize)), + dhcpv4.WithOption(dhcpv4.OptClassIdentifier(vendorClassID)), + dhcpv4.WithOption(OptVendorOptions(vendorOpts...)), )...) + if err != nil { + return nil, err + } + return PacketFor(d), nil } // InformSelectForAck constructs an INFORM[SELECT] packet given an ACK to the // previously-sent INFORM[LIST]. -func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootImage) (*dhcpv4.DHCPv4, error) { +func InformSelectForAck(ack *Packet, replyPort uint16, selectedImage BootImage) (*Packet, error) { if needsReplyPort(replyPort) && replyPort >= 1024 { return nil, errors.New("replyPort must be a privileged port") } // Data for OptionSelectedBootImageID vendorOpts := []dhcpv4.Option{ - &OptMessageType{MessageTypeSelect}, - Version1_1, - &OptSelectedBootImageID{selectedImage.ID}, + OptMessageType(MessageTypeSelect), + OptVersion(Version1_1), + OptSelectedBootImageID(selectedImage.ID), } // Validate replyPort if requested. if needsReplyPort(replyPort) { - vendorOpts = append(vendorOpts, &OptReplyPort{replyPort}) + vendorOpts = append(vendorOpts, OptReplyPort(replyPort)) } // Find server IP address - var serverIP net.IP - if opt := ack.GetOneOption(dhcpv4.OptionServerIdentifier); opt != nil { - serverIP = opt.(*dhcpv4.OptServerIdentifier).ServerID - } + serverIP := dhcpv4.GetServerIdentifier(ack.Options) if serverIP.To4() == nil { return nil, fmt.Errorf("could not parse server identifier from ACK") } - vendorOpts = append(vendorOpts, &OptServerIdentifier{serverIP}) + vendorOpts = append(vendorOpts, OptServerIdentifier(serverIP)) vendorClassID, err := MakeVendorClassIdentifier() if err != nil { return nil, err } - return dhcpv4.New(dhcpv4.WithReply(&ack), - dhcpv4.WithOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}), + d, err := dhcpv4.New(dhcpv4.WithReply(ack.v4()), + dhcpv4.WithOption(dhcpv4.OptClassIdentifier(vendorClassID)), dhcpv4.WithRequestedOptions( dhcpv4.OptionSubnetMask, dhcpv4.OptionRouter, @@ -171,20 +177,24 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI dhcpv4.OptionClassIdentifier, ), dhcpv4.WithMessageType(dhcpv4.MessageTypeInform), - dhcpv4.WithOption(&OptVendorSpecificInformation{vendorOpts}), + dhcpv4.WithOption(OptVendorOptions(vendorOpts...)), ) + if err != nil { + return nil, err + } + return PacketFor(d), nil } // NewReplyForInformList constructs an ACK for the INFORM[LIST] packet `inform` // with additional options in `config`. -func NewReplyForInformList(inform *dhcpv4.DHCPv4, config ReplyConfig) (*dhcpv4.DHCPv4, error) { +func NewReplyForInformList(inform *Packet, config ReplyConfig) (*Packet, error) { if config.DefaultImage == nil { return nil, errors.New("NewReplyForInformList: no default boot image ID set") } if config.Images == nil || len(config.Images) == 0 { return nil, errors.New("NewReplyForInformList: no boot images provided") } - reply, err := dhcpv4.NewReplyFromRequest(inform) + reply, err := dhcpv4.NewReplyFromRequest(&inform.DHCPv4) if err != nil { return nil, err } @@ -193,34 +203,34 @@ func NewReplyForInformList(inform *dhcpv4.DHCPv4, config ReplyConfig) (*dhcpv4.D reply.ServerIPAddr = config.ServerIP reply.ServerHostName = config.ServerHostname - reply.UpdateOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) - reply.UpdateOption(&dhcpv4.OptServerIdentifier{ServerID: config.ServerIP}) - reply.UpdateOption(&dhcpv4.OptClassIdentifier{Identifier: AppleVendorID}) + reply.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck)) + reply.UpdateOption(dhcpv4.OptServerIdentifier(config.ServerIP)) + reply.UpdateOption(dhcpv4.OptClassIdentifier(AppleVendorID)) // BSDP opts. vendorOpts := []dhcpv4.Option{ - &OptMessageType{Type: MessageTypeList}, - &OptServerPriority{Priority: config.ServerPriority}, - &OptDefaultBootImageID{ID: config.DefaultImage.ID}, - &OptBootImageList{Images: config.Images}, + OptMessageType(MessageTypeList), + OptServerPriority(config.ServerPriority), + OptDefaultBootImageID(config.DefaultImage.ID), + OptBootImageList(config.Images...), } if config.SelectedImage != nil { - vendorOpts = append(vendorOpts, &OptSelectedBootImageID{ID: config.SelectedImage.ID}) + vendorOpts = append(vendorOpts, OptSelectedBootImageID(config.SelectedImage.ID)) } - reply.UpdateOption(&OptVendorSpecificInformation{Options: vendorOpts}) - return reply, nil + reply.UpdateOption(OptVendorOptions(vendorOpts...)) + return PacketFor(reply), nil } // NewReplyForInformSelect constructs an ACK for the INFORM[Select] packet // `inform` with additional options in `config`. -func NewReplyForInformSelect(inform *dhcpv4.DHCPv4, config ReplyConfig) (*dhcpv4.DHCPv4, error) { +func NewReplyForInformSelect(inform *Packet, config ReplyConfig) (*Packet, error) { if config.SelectedImage == nil { return nil, errors.New("NewReplyForInformSelect: no selected boot image ID set") } if config.Images == nil || len(config.Images) == 0 { return nil, errors.New("NewReplyForInformSelect: no boot images provided") } - reply, err := dhcpv4.NewReplyFromRequest(inform) + reply, err := dhcpv4.NewReplyFromRequest(&inform.DHCPv4) if err != nil { return nil, err } @@ -231,16 +241,14 @@ func NewReplyForInformSelect(inform *dhcpv4.DHCPv4, config ReplyConfig) (*dhcpv4 reply.ServerHostName = config.ServerHostname reply.BootFileName = config.BootFileName - reply.UpdateOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) - reply.UpdateOption(&dhcpv4.OptServerIdentifier{ServerID: config.ServerIP}) - reply.UpdateOption(&dhcpv4.OptClassIdentifier{Identifier: AppleVendorID}) + reply.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck)) + reply.UpdateOption(dhcpv4.OptServerIdentifier(config.ServerIP)) + reply.UpdateOption(dhcpv4.OptClassIdentifier(AppleVendorID)) // BSDP opts. - reply.UpdateOption(&OptVendorSpecificInformation{ - Options: []dhcpv4.Option{ - &OptMessageType{Type: MessageTypeSelect}, - &OptSelectedBootImageID{ID: config.SelectedImage.ID}, - }, - }) - return reply, nil + reply.UpdateOption(OptVendorOptions( + OptMessageType(MessageTypeSelect), + OptSelectedBootImageID(config.SelectedImage.ID), + )) + return PacketFor(reply), nil } diff --git a/dhcpv4/bsdp/bsdp_option_boot_image_list.go b/dhcpv4/bsdp/bsdp_option_boot_image_list.go index 3282fa3..ebbbd2d 100644 --- a/dhcpv4/bsdp/bsdp_option_boot_image_list.go +++ b/dhcpv4/bsdp/bsdp_option_boot_image_list.go @@ -1,52 +1,66 @@ package bsdp import ( + "strings" + "github.com/insomniacslk/dhcp/dhcpv4" "github.com/u-root/u-root/pkg/uio" ) -// OptBootImageList contains the list of boot images presented by a netboot -// server. -type OptBootImageList struct { - Images []BootImage -} +// BootImageList contains a list of boot images presented by a netboot server. +// +// Implements the BSDP option listing the boot images. +type BootImageList []BootImage -// ParseOptBootImageList constructs an OptBootImageList struct from a sequence -// of bytes and returns it, or an error. -func ParseOptBootImageList(data []byte) (*OptBootImageList, error) { +// FromBytes deserializes data into bil. +func (bil *BootImageList) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) - var bootImages []BootImage for buf.Has(5) { var image BootImage - if err := (&image).Unmarshal(buf); err != nil { - return nil, err + if err := image.Unmarshal(buf); err != nil { + return err } - bootImages = append(bootImages, image) + *bil = append(*bil, image) } - - return &OptBootImageList{bootImages}, nil -} - -// Code returns the option code. -func (o *OptBootImageList) Code() dhcpv4.OptionCode { - return OptionBootImageList + return nil } // ToBytes returns a serialized stream of bytes for this option. -func (o *OptBootImageList) ToBytes() []byte { +func (bil BootImageList) ToBytes() []byte { buf := uio.NewBigEndianBuffer(nil) - for _, image := range o.Images { + for _, image := range bil { image.Marshal(buf) } return buf.Data() } // String returns a human-readable string for this option. -func (o *OptBootImageList) String() string { - s := "BSDP Boot Image List ->" - for _, image := range o.Images { - s += "\n " + image.String() +func (bil BootImageList) String() string { + s := make([]string, 0, len(bil)) + for _, image := range bil { + s = append(s, image.String()) + } + return strings.Join(s, ", ") +} + +// OptBootImageList returns a new BSDP boot image list. +func OptBootImageList(b ...BootImage) dhcpv4.Option { + return dhcpv4.Option{ + Code: OptionBootImageList, + Value: BootImageList(b), + } +} + +// GetBootImageList returns the BSDP boot image list. +func GetBootImageList(o dhcpv4.Options) BootImageList { + v := o.Get(OptionBootImageList) + if v == nil { + return nil + } + var bil BootImageList + if err := bil.FromBytes(v); err != nil { + return nil } - return s + return bil } diff --git a/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go b/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go index 5d1b77c..6282156 100644 --- a/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go +++ b/dhcpv4/bsdp/bsdp_option_boot_image_list_test.go @@ -25,8 +25,8 @@ func TestOptBootImageListInterfaceMethods(t *testing.T) { Name: "bsdp-2", }, } - o := OptBootImageList{bs} - require.Equal(t, OptionBootImageList, o.Code(), "Code") + o := OptBootImageList(bs...) + require.Equal(t, OptionBootImageList, o.Code, "Code") expectedBytes := []byte{ // boot image 1 0x1, 0x0, 0x03, 0xe9, // ID @@ -37,7 +37,7 @@ func TestOptBootImageListInterfaceMethods(t *testing.T) { 6, // name length 'b', 's', 'd', 'p', '-', '2', } - require.Equal(t, expectedBytes, o.ToBytes(), "ToBytes") + require.Equal(t, expectedBytes, o.Value.ToBytes(), "ToBytes") } func TestParseOptBootImageList(t *testing.T) { @@ -51,9 +51,10 @@ func TestParseOptBootImageList(t *testing.T) { 6, // name length 'b', 's', 'd', 'p', '-', '2', } - o, err := ParseOptBootImageList(data) + var o BootImageList + err := o.FromBytes(data) require.NoError(t, err) - expectedBootImages := []BootImage{ + expectedBootImages := BootImageList{ BootImage{ ID: BootImageID{ IsInstall: false, @@ -71,7 +72,7 @@ func TestParseOptBootImageList(t *testing.T) { Name: "bsdp-2", }, } - require.Equal(t, &OptBootImageList{expectedBootImages}, o) + require.Equal(t, expectedBootImages, o) // Error parsing boot image (malformed) data = []byte{ @@ -84,7 +85,7 @@ func TestParseOptBootImageList(t *testing.T) { 6, // name length 'b', 's', 'd', 'p', '-', '2', } - _, err = ParseOptBootImageList(data) + err = o.FromBytes(data) require.Error(t, err, "should get error from bad boot image") } @@ -107,7 +108,7 @@ func TestOptBootImageListString(t *testing.T) { Name: "bsdp-2", }, } - o := OptBootImageList{bs} - expectedString := "BSDP Boot Image List ->\n bsdp-1 [1001] uninstallable macOS image\n bsdp-2 [9009] installable macOS 9 image" + o := OptBootImageList(bs...) + expectedString := "BSDP Boot Image List: bsdp-1 [1001] uninstallable macOS image, bsdp-2 [9009] installable macOS 9 image" require.Equal(t, expectedString, o.String()) } diff --git a/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go b/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go deleted file mode 100644 index 40ab0be..0000000 --- a/dhcpv4/bsdp/bsdp_option_default_boot_image_id.go +++ /dev/null @@ -1,42 +0,0 @@ -package bsdp - -import ( - "fmt" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/u-root/u-root/pkg/uio" -) - -// OptDefaultBootImageID contains the selected boot image ID. -// -// Implements the BSDP option default boot image ID, which tells the client -// which image is the default boot image if one is not selected. -type OptDefaultBootImageID struct { - ID BootImageID -} - -// ParseOptDefaultBootImageID constructs an OptDefaultBootImageID struct from a sequence of -// bytes and returns it, or an error. -func ParseOptDefaultBootImageID(data []byte) (*OptDefaultBootImageID, error) { - var o OptDefaultBootImageID - buf := uio.NewBigEndianBuffer(data) - if err := o.ID.Unmarshal(buf); err != nil { - return nil, err - } - return &o, buf.FinError() -} - -// Code returns the option code. -func (o *OptDefaultBootImageID) Code() dhcpv4.OptionCode { - return OptionDefaultBootImageID -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptDefaultBootImageID) ToBytes() []byte { - return uio.ToBigEndian(o.ID) -} - -// String returns a human-readable string for this option. -func (o *OptDefaultBootImageID) String() string { - return fmt.Sprintf("BSDP Default Boot Image ID -> %s", o.ID.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go b/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go deleted file mode 100644 index a5abdaf..0000000 --- a/dhcpv4/bsdp/bsdp_option_default_boot_image_id_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package bsdp - -import ( - "testing" - - "github.com/stretchr/testify/require" - "github.com/u-root/u-root/pkg/uio" -) - -func TestOptDefaultBootImageIDInterfaceMethods(t *testing.T) { - b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - o := OptDefaultBootImageID{b} - require.Equal(t, OptionDefaultBootImageID, o.Code(), "Code") - require.Equal(t, uio.ToBigEndian(b), o.ToBytes(), "ToBytes") -} - -func TestParseOptDefaultBootImageID(t *testing.T) { - b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - o, err := ParseOptDefaultBootImageID(uio.ToBigEndian(b)) - require.NoError(t, err) - require.Equal(t, &OptDefaultBootImageID{b}, o) - - // Short byte stream - data := []byte{} - _, err = ParseOptDefaultBootImageID(data) - require.Error(t, err, "should get error from short byte stream") - - // Bad length - data = []byte{1, 0, 0, 0, 0} - _, err = ParseOptDefaultBootImageID(data) - require.Error(t, err, "should get error from bad length") -} - -func TestOptDefaultBootImageIDString(t *testing.T) { - b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - o := OptDefaultBootImageID{b} - require.Equal(t, "BSDP Default Boot Image ID -> [1001] installable macOS image", o.String()) - - b = BootImageID{IsInstall: false, ImageType: BootImageTypeMacOS9, Index: 1001} - o = OptDefaultBootImageID{b} - require.Equal(t, "BSDP Default Boot Image ID -> [1001] uninstallable macOS 9 image", o.String()) - - b = BootImageID{IsInstall: false, ImageType: BootImageType(99), Index: 1001} - o = OptDefaultBootImageID{b} - require.Equal(t, "BSDP Default Boot Image ID -> [1001] uninstallable unknown image", o.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_generic.go b/dhcpv4/bsdp/bsdp_option_generic.go deleted file mode 100644 index e9e163f..0000000 --- a/dhcpv4/bsdp/bsdp_option_generic.go +++ /dev/null @@ -1,36 +0,0 @@ -package bsdp - -import ( - "fmt" - - "github.com/insomniacslk/dhcp/dhcpv4" -) - -// OptGeneric is an option that only contains the option code and associated -// data. Every option that does not have a specific implementation will fall -// back to this option. -type OptGeneric struct { - OptionCode dhcpv4.OptionCode - Data []byte -} - -// ParseOptGeneric parses a bytestream and creates a new OptGeneric from it, -// or an error. -func ParseOptGeneric(code dhcpv4.OptionCode, data []byte) (*OptGeneric, error) { - return &OptGeneric{OptionCode: code, Data: data}, nil -} - -// Code returns the generic option code. -func (o OptGeneric) Code() dhcpv4.OptionCode { - return o.OptionCode -} - -// ToBytes returns a serialized generic option as a slice of bytes. -func (o OptGeneric) ToBytes() []byte { - return o.Data -} - -// String returns a human-readable representation of a generic option. -func (o OptGeneric) String() string { - return fmt.Sprintf("%s -> %v", o.OptionCode, o.Data) -} diff --git a/dhcpv4/bsdp/bsdp_option_generic_test.go b/dhcpv4/bsdp/bsdp_option_generic_test.go deleted file mode 100644 index a813f95..0000000 --- a/dhcpv4/bsdp/bsdp_option_generic_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package bsdp - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestParseOptGeneric(t *testing.T) { - // Good parse - o, err := ParseOptGeneric(OptionMessageType, []byte{1}) - require.NoError(t, err) - require.Equal(t, OptionMessageType, o.Code()) - require.Equal(t, MessageTypeList, MessageType(o.Data[0])) -} - -func TestOptGenericCode(t *testing.T) { - o := OptGeneric{ - OptionCode: OptionMessageType, - Data: []byte{byte(MessageTypeList)}, - } - require.Equal(t, OptionMessageType, o.Code()) -} - -func TestOptGenericData(t *testing.T) { - o := OptGeneric{ - OptionCode: OptionServerIdentifier, - Data: []byte{192, 168, 0, 1}, - } - require.Equal(t, []byte{192, 168, 0, 1}, o.Data) -} - -func TestOptGenericToBytes(t *testing.T) { - o := OptGeneric{ - OptionCode: OptionServerIdentifier, - Data: []byte{192, 168, 0, 1}, - } - serialized := o.ToBytes() - expected := []byte{192, 168, 0, 1} - require.Equal(t, expected, serialized) -} - -func TestOptGenericString(t *testing.T) { - o := OptGeneric{ - OptionCode: OptionServerIdentifier, - Data: []byte{192, 168, 0, 1}, - } - require.Equal(t, "BSDP Server Identifier -> [192 168 0 1]", o.String()) -} - -func TestOptGenericStringUnknown(t *testing.T) { - o := OptGeneric{ - OptionCode: optionCode(102), // Returned option code. - Data: []byte{5}, - } - require.Equal(t, "unknown -> [5]", o.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_machine_name.go b/dhcpv4/bsdp/bsdp_option_machine_name.go deleted file mode 100644 index ced88b0..0000000 --- a/dhcpv4/bsdp/bsdp_option_machine_name.go +++ /dev/null @@ -1,34 +0,0 @@ -package bsdp - -import ( - "github.com/insomniacslk/dhcp/dhcpv4" -) - -// OptMachineName represents a BSDP message type. -// -// Implements the BSDP option machine name, which gives the Netboot server's -// machine name. -type OptMachineName struct { - Name string -} - -// ParseOptMachineName constructs an OptMachineName struct from a sequence of -// bytes and returns it, or an error. -func ParseOptMachineName(data []byte) (*OptMachineName, error) { - return &OptMachineName{Name: string(data)}, nil -} - -// Code returns the option code. -func (o *OptMachineName) Code() dhcpv4.OptionCode { - return OptionMachineName -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptMachineName) ToBytes() []byte { - return []byte(o.Name) -} - -// String returns a human-readable string for this option. -func (o *OptMachineName) String() string { - return "BSDP Machine Name -> " + o.Name -} diff --git a/dhcpv4/bsdp/bsdp_option_machine_name_test.go b/dhcpv4/bsdp/bsdp_option_machine_name_test.go deleted file mode 100644 index abc0d54..0000000 --- a/dhcpv4/bsdp/bsdp_option_machine_name_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package bsdp - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestOptMachineNameInterfaceMethods(t *testing.T) { - o := OptMachineName{"somebox"} - require.Equal(t, OptionMachineName, o.Code(), "Code") - expectedBytes := []byte{'s', 'o', 'm', 'e', 'b', 'o', 'x'} - require.Equal(t, expectedBytes, o.ToBytes(), "ToBytes") -} - -func TestParseOptMachineName(t *testing.T) { - data := []byte{'s', 'o', 'm', 'e', 'b', 'o', 'x'} - o, err := ParseOptMachineName(data) - require.NoError(t, err) - require.Equal(t, &OptMachineName{"somebox"}, o) -} - -func TestOptMachineNameString(t *testing.T) { - o := OptMachineName{"somebox"} - require.Equal(t, "BSDP Machine Name -> somebox", o.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_message_type.go b/dhcpv4/bsdp/bsdp_option_message_type.go index 5f96f12..cb0c5cf 100644 --- a/dhcpv4/bsdp/bsdp_option_message_type.go +++ b/dhcpv4/bsdp/bsdp_option_message_type.go @@ -15,16 +15,23 @@ type MessageType byte // BSDP Message types - e.g. LIST, SELECT, FAILED const ( + MessageTypeNone MessageType = 0 MessageTypeList MessageType = 1 MessageTypeSelect MessageType = 2 MessageTypeFailed MessageType = 3 ) +// ToBytes returns a serialized stream of bytes for this option. +func (m MessageType) ToBytes() []byte { + return []byte{byte(m)} +} + +// String returns a human-friendly representation of MessageType. func (m MessageType) String() string { if s, ok := messageTypeToString[m]; ok { return s } - return "Unknown" + return fmt.Sprintf("unknown (%d)", m) } // messageTypeToString maps each BSDP message type to a human-readable string. @@ -34,29 +41,30 @@ var messageTypeToString = map[MessageType]string{ MessageTypeFailed: "FAILED", } -// OptMessageType represents a BSDP message type. -type OptMessageType struct { - Type MessageType -} - -// ParseOptMessageType constructs an OptMessageType struct from a sequence of -// bytes and returns it, or an error. -func ParseOptMessageType(data []byte) (*OptMessageType, error) { +// FromBytes reads data into m. +func (m *MessageType) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) - return &OptMessageType{Type: MessageType(buf.Read8())}, buf.FinError() + *m = MessageType(buf.Read8()) + return buf.FinError() } -// Code returns the option code. -func (o *OptMessageType) Code() dhcpv4.OptionCode { - return OptionMessageType -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptMessageType) ToBytes() []byte { - return []byte{byte(o.Type)} +// OptMessageType returns a new BSDP Message Type option. +func OptMessageType(mt MessageType) dhcpv4.Option { + return dhcpv4.Option{ + Code: OptionMessageType, + Value: mt, + } } -// String returns a human-readable string for this option. -func (o *OptMessageType) String() string { - return fmt.Sprintf("BSDP Message Type -> %s", o.Type.String()) +// GetMessageType returns the BSDP Message Type in o. +func GetMessageType(o dhcpv4.Options) MessageType { + v := o.Get(OptionMessageType) + if v == nil { + return MessageTypeNone + } + var m MessageType + if err := m.FromBytes(v); err != nil { + return MessageTypeNone + } + return m } diff --git a/dhcpv4/bsdp/bsdp_option_message_type_test.go b/dhcpv4/bsdp/bsdp_option_message_type_test.go index a6695cc..6666137 100644 --- a/dhcpv4/bsdp/bsdp_option_message_type_test.go +++ b/dhcpv4/bsdp/bsdp_option_message_type_test.go @@ -7,24 +7,25 @@ import ( ) func TestOptMessageTypeInterfaceMethods(t *testing.T) { - o := OptMessageType{MessageTypeList} - require.Equal(t, OptionMessageType, o.Code(), "Code") - require.Equal(t, []byte{1}, o.ToBytes(), "ToBytes") + o := OptMessageType(MessageTypeList) + require.Equal(t, OptionMessageType, o.Code, "Code") + require.Equal(t, []byte{1}, o.Value.ToBytes(), "ToBytes") } func TestParseOptMessageType(t *testing.T) { + var o MessageType data := []byte{1} // DISCOVER - o, err := ParseOptMessageType(data) + err := o.FromBytes(data) require.NoError(t, err) - require.Equal(t, &OptMessageType{MessageTypeList}, o) + require.Equal(t, MessageTypeList, o) } func TestOptMessageTypeString(t *testing.T) { // known - o := OptMessageType{MessageTypeList} - require.Equal(t, "BSDP Message Type -> LIST", o.String()) + o := OptMessageType(MessageTypeList) + require.Equal(t, "BSDP Message Type: LIST", o.String()) // unknown - o = OptMessageType{99} - require.Equal(t, "BSDP Message Type -> Unknown", o.String()) + o = OptMessageType(99) + require.Equal(t, "BSDP Message Type: unknown (99)", o.String()) } diff --git a/dhcpv4/bsdp/bsdp_option_misc.go b/dhcpv4/bsdp/bsdp_option_misc.go new file mode 100644 index 0000000..2d3a7bf --- /dev/null +++ b/dhcpv4/bsdp/bsdp_option_misc.go @@ -0,0 +1,99 @@ +package bsdp + +import ( + "fmt" + "net" + + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/u-root/u-root/pkg/uio" +) + +// OptReplyPort returns a new BSDP reply port option. +// +// Implements the BSDP option reply port. This is used when BSDP responses +// should be sent to a reply port other than the DHCP default. The macOS GUI +// "Startup Disk Select" sends this option since it's operating in an +// unprivileged context. +func OptReplyPort(port uint16) dhcpv4.Option { + return dhcpv4.Option{Code: OptionReplyPort, Value: dhcpv4.Uint16(port)} +} + +// GetReplyPort returns the BSDP reply port in o, if present. +func GetReplyPort(o dhcpv4.Options) (uint16, error) { + return dhcpv4.GetUint16(OptionReplyPort, o) +} + +// OptServerPriority returns a new BSDP server priority option. +func OptServerPriority(prio uint16) dhcpv4.Option { + return dhcpv4.Option{Code: OptionServerPriority, Value: dhcpv4.Uint16(prio)} +} + +// GetServerPriority returns the BSDP server priority in o if present. +func GetServerPriority(o dhcpv4.Options) (uint16, error) { + return dhcpv4.GetUint16(OptionServerPriority, o) +} + +// OptMachineName returns a BSDP Machine Name option. +func OptMachineName(name string) dhcpv4.Option { + return dhcpv4.Option{Code: OptionMachineName, Value: dhcpv4.String(name)} +} + +// GetMachineName finds and parses the BSDP Machine Name option from o. +func GetMachineName(o dhcpv4.Options) string { + return dhcpv4.GetString(OptionMachineName, o) +} + +// Version is the BSDP protocol version. Can be one of 1.0 or 1.1. +type Version [2]byte + +// Specific versions. +var ( + Version1_0 = Version{1, 0} + Version1_1 = Version{1, 1} +) + +// ToBytes returns a serialized stream of bytes for this option. +func (o Version) ToBytes() []byte { + return o[:] +} + +// String returns a human-readable string for this option. +func (o Version) String() string { + return fmt.Sprintf("%d.%d", o[0], o[1]) +} + +// FromBytes constructs a Version struct from a sequence of +// bytes and returns it, or an error. +func (o *Version) FromBytes(data []byte) error { + buf := uio.NewBigEndianBuffer(data) + buf.ReadBytes(o[:]) + return buf.FinError() +} + +// OptVersion returns a new BSDP version option. +func OptVersion(version Version) dhcpv4.Option { + return dhcpv4.Option{Code: OptionVersion, Value: version} +} + +// GetVersion returns the BSDP version in o if present. +func GetVersion(o dhcpv4.Options) (Version, error) { + v := o.Get(OptionVersion) + if v == nil { + return Version{0, 0}, fmt.Errorf("version not found") + } + var ver Version + if err := ver.FromBytes(v); err != nil { + return Version{0, 0}, err + } + return ver, nil +} + +// GetServerIdentifier returns the BSDP Server Identifier value in o. +func GetServerIdentifier(o dhcpv4.Options) net.IP { + return dhcpv4.GetIP(OptionServerIdentifier, o) +} + +// OptServerIdentifier returns a new BSDP Server Identifier option. +func OptServerIdentifier(ip net.IP) dhcpv4.Option { + return dhcpv4.Option{Code: OptionServerIdentifier, Value: dhcpv4.IP(ip)} +} diff --git a/dhcpv4/bsdp/bsdp_option_misc_test.go b/dhcpv4/bsdp/bsdp_option_misc_test.go new file mode 100644 index 0000000..dfa81b5 --- /dev/null +++ b/dhcpv4/bsdp/bsdp_option_misc_test.go @@ -0,0 +1,95 @@ +package bsdp + +import ( + "net" + "testing" + + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/stretchr/testify/require" +) + +func TestOptReplyPort(t *testing.T) { + o := OptReplyPort(1234) + require.Equal(t, OptionReplyPort, o.Code, "Code") + require.Equal(t, []byte{4, 210}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "BSDP Reply Port: 1234", o.String()) +} + +func TestGetReplyPort(t *testing.T) { + o := VendorOptions{dhcpv4.OptionsFromList(OptReplyPort(1234))} + port, err := GetReplyPort(o.Options) + require.NoError(t, err) + require.Equal(t, uint16(1234), port) + + port, err = GetReplyPort(dhcpv4.Options{}) + require.Error(t, err, "no reply port present") +} + +func TestOptServerPriority(t *testing.T) { + o := OptServerPriority(1234) + require.Equal(t, OptionServerPriority, o.Code, "Code") + require.Equal(t, []byte{4, 210}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "BSDP Server Priority: 1234", o.String()) +} + +func TestGetServerPriority(t *testing.T) { + o := VendorOptions{dhcpv4.OptionsFromList(OptServerPriority(1234))} + prio, err := GetServerPriority(o.Options) + require.NoError(t, err) + require.Equal(t, uint16(1234), prio) + + prio, err = GetServerPriority(dhcpv4.Options{}) + require.Error(t, err, "no server prio present") +} + +func TestOptMachineName(t *testing.T) { + o := OptMachineName("foo") + require.Equal(t, OptionMachineName, o.Code, "Code") + require.Equal(t, []byte("foo"), o.Value.ToBytes(), "ToBytes") + require.Equal(t, "BSDP Machine Name: foo", o.String()) +} + +func TestGetMachineName(t *testing.T) { + o := VendorOptions{dhcpv4.OptionsFromList(OptMachineName("foo"))} + require.Equal(t, "foo", GetMachineName(o.Options)) + require.Equal(t, "", GetMachineName(dhcpv4.Options{})) +} + +func TestOptVersion(t *testing.T) { + o := OptVersion(Version1_1) + require.Equal(t, OptionVersion, o.Code, "Code") + require.Equal(t, []byte{1, 1}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "BSDP Version: 1.1", o.String()) +} + +func TestGetVersion(t *testing.T) { + o := VendorOptions{dhcpv4.OptionsFromList(OptVersion(Version1_1))} + ver, err := GetVersion(o.Options) + require.NoError(t, err) + require.Equal(t, ver, Version1_1) + + ver, err = GetVersion(dhcpv4.Options{}) + require.Error(t, err, "no version present") + + ver, err = GetVersion(dhcpv4.Options{OptionVersion.Code(): []byte{}}) + require.Error(t, err, "empty version field") + + ver, err = GetVersion(dhcpv4.Options{OptionVersion.Code(): []byte{1}}) + require.Error(t, err, "version option too short") + + ver, err = GetVersion(dhcpv4.Options{OptionVersion.Code(): []byte{1, 2, 3}}) + require.Error(t, err, "version option too long") +} + +func TestOptServerIdentifier(t *testing.T) { + o := OptServerIdentifier(net.IP{1, 1, 1, 1}) + require.Equal(t, OptionServerIdentifier, o.Code, "Code") + require.Equal(t, []byte{1, 1, 1, 1}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "BSDP Server Identifier: 1.1.1.1", o.String()) +} + +func TestGetServerIdentifier(t *testing.T) { + o := VendorOptions{dhcpv4.OptionsFromList(OptServerIdentifier(net.IP{1, 1, 1, 1}))} + require.Equal(t, net.IP{1, 1, 1, 1}, GetServerIdentifier(o.Options)) + require.Equal(t, net.IP(nil), GetServerIdentifier(dhcpv4.Options{})) +} diff --git a/dhcpv4/bsdp/bsdp_option_reply_port.go b/dhcpv4/bsdp/bsdp_option_reply_port.go deleted file mode 100644 index 5eea5ee..0000000 --- a/dhcpv4/bsdp/bsdp_option_reply_port.go +++ /dev/null @@ -1,42 +0,0 @@ -package bsdp - -import ( - "fmt" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/u-root/u-root/pkg/uio" -) - -// OptReplyPort represents a BSDP protocol version. -// -// Implements the BSDP option reply port. This is used when BSDP responses -// should be sent to a reply port other than the DHCP default. The macOS GUI -// "Startup Disk Select" sends this option since it's operating in an -// unprivileged context. -type OptReplyPort struct { - Port uint16 -} - -// ParseOptReplyPort constructs an OptReplyPort struct from a sequence of -// bytes and returns it, or an error. -func ParseOptReplyPort(data []byte) (*OptReplyPort, error) { - buf := uio.NewBigEndianBuffer(data) - return &OptReplyPort{buf.Read16()}, buf.FinError() -} - -// Code returns the option code. -func (o *OptReplyPort) Code() dhcpv4.OptionCode { - return OptionReplyPort -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptReplyPort) ToBytes() []byte { - buf := uio.NewBigEndianBuffer(nil) - buf.Write16(o.Port) - return buf.Data() -} - -// String returns a human-readable string for this option. -func (o *OptReplyPort) String() string { - return fmt.Sprintf("BSDP Reply Port -> %v", o.Port) -} diff --git a/dhcpv4/bsdp/bsdp_option_reply_port_test.go b/dhcpv4/bsdp/bsdp_option_reply_port_test.go deleted file mode 100644 index de94ffb..0000000 --- a/dhcpv4/bsdp/bsdp_option_reply_port_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package bsdp - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestOptReplyPortInterfaceMethods(t *testing.T) { - o := OptReplyPort{1234} - require.Equal(t, OptionReplyPort, o.Code(), "Code") - require.Equal(t, []byte{4, 210}, o.ToBytes(), "ToBytes") -} - -func TestParseOptReplyPort(t *testing.T) { - data := []byte{0, 1} - o, err := ParseOptReplyPort(data) - require.NoError(t, err) - require.Equal(t, &OptReplyPort{1}, o) - - // Short byte stream - data = []byte{} - _, err = ParseOptReplyPort(data) - require.Error(t, err, "should get error from short byte stream") - - // Bad length - data = []byte{1} - _, err = ParseOptReplyPort(data) - require.Error(t, err, "should get error from bad length") -} - -func TestOptReplyPortString(t *testing.T) { - // known - o := OptReplyPort{1234} - require.Equal(t, "BSDP Reply Port -> 1234", o.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go deleted file mode 100644 index 67f99a8..0000000 --- a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id.go +++ /dev/null @@ -1,42 +0,0 @@ -package bsdp - -import ( - "fmt" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/u-root/u-root/pkg/uio" -) - -// OptSelectedBootImageID contains the selected boot image ID. -// -// Implements the BSDP option selected boot image ID, which tells the server -// which boot image has been selected by the client. -type OptSelectedBootImageID struct { - ID BootImageID -} - -// ParseOptSelectedBootImageID constructs an OptSelectedBootImageID struct from a sequence of -// bytes and returns it, or an error. -func ParseOptSelectedBootImageID(data []byte) (*OptSelectedBootImageID, error) { - var o OptSelectedBootImageID - buf := uio.NewBigEndianBuffer(data) - if err := o.ID.Unmarshal(buf); err != nil { - return nil, err - } - return &o, buf.FinError() -} - -// Code returns the option code. -func (o *OptSelectedBootImageID) Code() dhcpv4.OptionCode { - return OptionSelectedBootImageID -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptSelectedBootImageID) ToBytes() []byte { - return uio.ToBigEndian(o.ID) -} - -// String returns a human-readable string for this option. -func (o *OptSelectedBootImageID) String() string { - return fmt.Sprintf("BSDP Selected Boot Image ID -> %s", o.ID.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go b/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go deleted file mode 100644 index e187fc7..0000000 --- a/dhcpv4/bsdp/bsdp_option_selected_boot_image_id_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package bsdp - -import ( - "testing" - - "github.com/stretchr/testify/require" - "github.com/u-root/u-root/pkg/uio" -) - -func TestOptSelectedBootImageIDInterfaceMethods(t *testing.T) { - b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - o := OptSelectedBootImageID{b} - require.Equal(t, OptionSelectedBootImageID, o.Code(), "Code") - require.Equal(t, uio.ToBigEndian(b), o.ToBytes(), "ToBytes") -} - -func TestParseOptSelectedBootImageID(t *testing.T) { - b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - o, err := ParseOptSelectedBootImageID(uio.ToBigEndian(b)) - require.NoError(t, err) - require.Equal(t, &OptSelectedBootImageID{b}, o) - - // Short byte stream - data := []byte{} - _, err = ParseOptSelectedBootImageID(data) - require.Error(t, err, "should get error from short byte stream") - - // Bad length - data = []byte{1, 0, 0, 0, 0} - _, err = ParseOptSelectedBootImageID(data) - require.Error(t, err, "should get error from bad length") -} - -func TestOptSelectedBootImageIDString(t *testing.T) { - b := BootImageID{IsInstall: true, ImageType: BootImageTypeMacOSX, Index: 1001} - o := OptSelectedBootImageID{b} - require.Equal(t, "BSDP Selected Boot Image ID -> [1001] installable macOS image", o.String()) - - b = BootImageID{IsInstall: false, ImageType: BootImageTypeMacOS9, Index: 1001} - o = OptSelectedBootImageID{b} - require.Equal(t, "BSDP Selected Boot Image ID -> [1001] uninstallable macOS 9 image", o.String()) - - b = BootImageID{IsInstall: false, ImageType: BootImageType(99), Index: 1001} - o = OptSelectedBootImageID{b} - require.Equal(t, "BSDP Selected Boot Image ID -> [1001] uninstallable unknown image", o.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_server_identifier.go b/dhcpv4/bsdp/bsdp_option_server_identifier.go deleted file mode 100644 index d1f5b6c..0000000 --- a/dhcpv4/bsdp/bsdp_option_server_identifier.go +++ /dev/null @@ -1,36 +0,0 @@ -package bsdp - -import ( - "fmt" - "net" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/u-root/u-root/pkg/uio" -) - -// OptServerIdentifier implements the BSDP server identifier option. -type OptServerIdentifier struct { - ServerID net.IP -} - -// ParseOptServerIdentifier returns a new OptServerIdentifier from a byte -// stream, or error if any. -func ParseOptServerIdentifier(data []byte) (*OptServerIdentifier, error) { - buf := uio.NewBigEndianBuffer(data) - return &OptServerIdentifier{ServerID: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() -} - -// Code returns the option code. -func (o *OptServerIdentifier) Code() dhcpv4.OptionCode { - return OptionServerIdentifier -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptServerIdentifier) ToBytes() []byte { - return o.ServerID.To4() -} - -// String returns a human-readable string. -func (o *OptServerIdentifier) String() string { - return fmt.Sprintf("BSDP Server Identifier -> %v", o.ServerID.String()) -} diff --git a/dhcpv4/bsdp/bsdp_option_server_identifier_test.go b/dhcpv4/bsdp/bsdp_option_server_identifier_test.go deleted file mode 100644 index 5a77644..0000000 --- a/dhcpv4/bsdp/bsdp_option_server_identifier_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package bsdp - -import ( - "net" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestOptServerIdentifierInterfaceMethods(t *testing.T) { - ip := net.IP{192, 168, 0, 1} - o := OptServerIdentifier{ServerID: ip} - require.Equal(t, OptionServerIdentifier, o.Code(), "Code") - expectedBytes := []byte{192, 168, 0, 1} - require.Equal(t, expectedBytes, o.ToBytes(), "ToBytes") - require.Equal(t, "BSDP Server Identifier -> 192.168.0.1", o.String(), "String") -} - -func TestParseOptServerIdentifier(t *testing.T) { - var ( - o *OptServerIdentifier - err error - ) - o, err = ParseOptServerIdentifier([]byte{}) - require.Error(t, err, "empty byte stream") - - o, err = ParseOptServerIdentifier([]byte{3, 4, 192}) - require.Error(t, err, "wrong IP length") - - o, err = ParseOptServerIdentifier([]byte{192, 168, 0, 1}) - require.NoError(t, err) - require.Equal(t, net.IP{192, 168, 0, 1}, o.ServerID) -} diff --git a/dhcpv4/bsdp/bsdp_option_server_priority.go b/dhcpv4/bsdp/bsdp_option_server_priority.go deleted file mode 100644 index f6fcf57..0000000 --- a/dhcpv4/bsdp/bsdp_option_server_priority.go +++ /dev/null @@ -1,37 +0,0 @@ -package bsdp - -import ( - "fmt" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/u-root/u-root/pkg/uio" -) - -// OptServerPriority represents an option encapsulating the server priority. -type OptServerPriority struct { - Priority uint16 -} - -// ParseOptServerPriority returns a new OptServerPriority from a byte stream, or -// error if any. -func ParseOptServerPriority(data []byte) (*OptServerPriority, error) { - buf := uio.NewBigEndianBuffer(data) - return &OptServerPriority{Priority: buf.Read16()}, buf.FinError() -} - -// Code returns the option code. -func (o *OptServerPriority) Code() dhcpv4.OptionCode { - return OptionServerPriority -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptServerPriority) ToBytes() []byte { - buf := uio.NewBigEndianBuffer(nil) - buf.Write16(o.Priority) - return buf.Data() -} - -// String returns a human-readable string. -func (o *OptServerPriority) String() string { - return fmt.Sprintf("BSDP Server Priority -> %v", o.Priority) -} diff --git a/dhcpv4/bsdp/bsdp_option_server_priority_test.go b/dhcpv4/bsdp/bsdp_option_server_priority_test.go deleted file mode 100644 index c4c96de..0000000 --- a/dhcpv4/bsdp/bsdp_option_server_priority_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package bsdp - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestOptServerPriorityInterfaceMethods(t *testing.T) { - o := OptServerPriority{Priority: 100} - require.Equal(t, OptionServerPriority, o.Code(), "Code") - require.Equal(t, []byte{0, 100}, o.ToBytes(), "ToBytes") - require.Equal(t, "BSDP Server Priority -> 100", o.String(), "String") -} - -func TestParseOptServerPriority(t *testing.T) { - var ( - o *OptServerPriority - err error - ) - o, err = ParseOptServerPriority([]byte{}) - require.Error(t, err, "empty byte stream") - - o, err = ParseOptServerPriority([]byte{1}) - require.Error(t, err, "short byte stream") - - o, err = ParseOptServerPriority([]byte{0, 100}) - require.NoError(t, err) - require.Equal(t, uint16(100), o.Priority) -} diff --git a/dhcpv4/bsdp/bsdp_option_version.go b/dhcpv4/bsdp/bsdp_option_version.go deleted file mode 100644 index d6b78c8..0000000 --- a/dhcpv4/bsdp/bsdp_option_version.go +++ /dev/null @@ -1,41 +0,0 @@ -package bsdp - -import ( - "fmt" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/u-root/u-root/pkg/uio" -) - -// Version is the BSDP protocol version. Can be one of 1.0 or 1.1. -type Version [2]byte - -// Specific versions. -var ( - Version1_0 = Version{1, 0} - Version1_1 = Version{1, 1} -) - -// ParseOptVersion constructs an OptVersion struct from a sequence of -// bytes and returns it, or an error. -func ParseOptVersion(data []byte) (Version, error) { - buf := uio.NewBigEndianBuffer(data) - var v Version - buf.ReadBytes(v[:]) - return v, buf.FinError() -} - -// Code returns the option code. -func (o Version) Code() dhcpv4.OptionCode { - return OptionVersion -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o Version) ToBytes() []byte { - return o[:] -} - -// String returns a human-readable string for this option. -func (o Version) String() string { - return fmt.Sprintf("BSDP Version -> %d.%d", o[0], o[1]) -} diff --git a/dhcpv4/bsdp/bsdp_option_version_test.go b/dhcpv4/bsdp/bsdp_option_version_test.go deleted file mode 100644 index 69d4c86..0000000 --- a/dhcpv4/bsdp/bsdp_option_version_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package bsdp - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestOptVersionInterfaceMethods(t *testing.T) { - o := Version1_1 - require.Equal(t, OptionVersion, o.Code(), "Code") - require.Equal(t, []byte{1, 1}, o.ToBytes(), "ToBytes") -} - -func TestParseOptVersion(t *testing.T) { - data := []byte{1, 1} - o, err := ParseOptVersion(data) - require.NoError(t, err) - require.Equal(t, Version1_1, o) - - // Short byte stream - data = []byte{2} - _, err = ParseOptVersion(data) - require.Error(t, err, "should get error from short byte stream") -} - -func TestOptVersionString(t *testing.T) { - require.Equal(t, "BSDP Version -> 1.1", Version1_1.String()) -} diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go index 638a408..e0378c2 100644 --- a/dhcpv4/bsdp/bsdp_test.go +++ b/dhcpv4/bsdp/bsdp_test.go @@ -12,9 +12,9 @@ import ( func RequireHasOption(t *testing.T, opts dhcpv4.Options, opt dhcpv4.Option) { require.NotNil(t, opts, "must pass list of options") require.NotNil(t, opt, "must pass option") - require.True(t, opts.Has(opt.Code())) - actual := opts.GetOne(opt.Code()) - require.Equal(t, opt, actual) + require.True(t, opts.Has(opt.Code)) + actual := opts.Get(opt.Code) + require.Equal(t, opt.Value.ToBytes(), actual) } func TestParseBootImageListFromAck(t *testing.T) { @@ -37,11 +37,11 @@ func TestParseBootImageListFromAck(t *testing.T) { }, } ack, _ := dhcpv4.New() - ack.UpdateOption(&OptVendorSpecificInformation{ - []dhcpv4.Option{&OptBootImageList{expectedBootImages}}, - }) + ack.UpdateOption(OptVendorOptions( + OptBootImageList(expectedBootImages...), + )) - images, err := ParseBootImageListFromAck(*ack) + images, err := ParseBootImageListFromAck(ack) require.NoError(t, err) require.NotEmpty(t, images, "should get BootImages") require.Equal(t, expectedBootImages, images, "should get same BootImages") @@ -49,7 +49,7 @@ func TestParseBootImageListFromAck(t *testing.T) { func TestParseBootImageListFromAckNoVendorOption(t *testing.T) { ack, _ := dhcpv4.New() - images, err := ParseBootImageListFromAck(*ack) + images, err := ParseBootImageListFromAck(ack) require.Error(t, err) require.Empty(t, images, "no BootImages") } @@ -70,14 +70,13 @@ func TestNewInformList_NoReplyPort(t *testing.T) { require.True(t, m.Options.Has(dhcpv4.OptionParameterRequestList)) require.True(t, m.Options.Has(dhcpv4.OptionMaximumDHCPMessageSize)) - opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) - require.NotNil(t, opt, "vendor opts not present") - vendorInfo := opt.(*OptVendorSpecificInformation) - require.True(t, vendorInfo.Options.Has(OptionMessageType)) - require.True(t, vendorInfo.Options.Has(OptionVersion)) + vendorOpts := GetVendorOptions(m.Options) + require.NotNil(t, vendorOpts, "vendor opts not present") + require.True(t, vendorOpts.Has(OptionMessageType)) + require.True(t, vendorOpts.Has(OptionVersion)) - opt = vendorInfo.GetOneOption(OptionMessageType) - require.Equal(t, MessageTypeList, opt.(*OptMessageType).Type) + mt := GetMessageType(vendorOpts.Options) + require.Equal(t, MessageTypeList, mt) } func TestNewInformList_ReplyPort(t *testing.T) { @@ -94,12 +93,12 @@ func TestNewInformList_ReplyPort(t *testing.T) { m, err := NewInformList(hwAddr, localIP, replyPort) require.NoError(t, err) - opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) - vendorInfo := opt.(*OptVendorSpecificInformation) - require.True(t, vendorInfo.Options.Has(OptionReplyPort)) + vendorOpts := GetVendorOptions(m.Options) + require.True(t, vendorOpts.Options.Has(OptionReplyPort)) - opt = vendorInfo.GetOneOption(OptionReplyPort) - require.Equal(t, replyPort, opt.(*OptReplyPort).Port) + port, err := GetReplyPort(vendorOpts.Options) + require.NoError(t, err) + require.Equal(t, replyPort, port) } func newAck(hwAddr net.HardwareAddr, transactionID [4]byte) *dhcpv4.DHCPv4 { @@ -108,7 +107,7 @@ func newAck(hwAddr net.HardwareAddr, transactionID [4]byte) *dhcpv4.DHCPv4 { ack.TransactionID = transactionID ack.HWType = iana.HWTypeEthernet ack.ClientHWAddr = hwAddr - ack.UpdateOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) + ack.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck)) return ack } @@ -126,9 +125,9 @@ func TestInformSelectForAck_Broadcast(t *testing.T) { } ack := newAck(hwAddr, tid) ack.SetBroadcast() - ack.UpdateOption(&dhcpv4.OptServerIdentifier{ServerID: serverID}) + ack.UpdateOption(dhcpv4.OptServerIdentifier(serverID)) - m, err := InformSelectForAck(*ack, 0, bootImage) + m, err := InformSelectForAck(PacketFor(ack), 0, bootImage) require.NoError(t, err) require.Equal(t, dhcpv4.OpcodeBootRequest, m.OpCode) require.Equal(t, ack.HWType, m.HWType) @@ -140,17 +139,16 @@ func TestInformSelectForAck_Broadcast(t *testing.T) { require.True(t, m.Options.Has(dhcpv4.OptionClassIdentifier)) require.True(t, m.Options.Has(dhcpv4.OptionParameterRequestList)) require.True(t, m.Options.Has(dhcpv4.OptionDHCPMessageType)) - opt := m.GetOneOption(dhcpv4.OptionDHCPMessageType) - require.Equal(t, dhcpv4.MessageTypeInform, opt.(*dhcpv4.OptMessageType).MessageType) + mt := dhcpv4.GetMessageType(m.Options) + require.Equal(t, dhcpv4.MessageTypeInform, mt) // Validate vendor opts. require.True(t, m.Options.Has(dhcpv4.OptionVendorSpecificInformation)) - opt = m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) - vendorInfo := opt.(*OptVendorSpecificInformation) - RequireHasOption(t, vendorInfo.Options, &OptMessageType{Type: MessageTypeSelect}) - require.True(t, vendorInfo.Options.Has(OptionVersion)) - RequireHasOption(t, vendorInfo.Options, &OptSelectedBootImageID{ID: bootImage.ID}) - RequireHasOption(t, vendorInfo.Options, &OptServerIdentifier{ServerID: serverID}) + vendorOpts := GetVendorOptions(m.Options).Options + RequireHasOption(t, vendorOpts, OptMessageType(MessageTypeSelect)) + require.True(t, vendorOpts.Has(OptionVersion)) + RequireHasOption(t, vendorOpts, OptSelectedBootImageID(bootImage.ID)) + RequireHasOption(t, vendorOpts, OptServerIdentifier(serverID)) } func TestInformSelectForAck_NoServerID(t *testing.T) { @@ -166,7 +164,7 @@ func TestInformSelectForAck_NoServerID(t *testing.T) { } ack := newAck(hwAddr, tid) - _, err := InformSelectForAck(*ack, 0, bootImage) + _, err := InformSelectForAck(PacketFor(ack), 0, bootImage) require.Error(t, err, "expect error for no server identifier option") } @@ -184,9 +182,9 @@ func TestInformSelectForAck_BadReplyPort(t *testing.T) { } ack := newAck(hwAddr, tid) ack.SetBroadcast() - ack.UpdateOption(&dhcpv4.OptServerIdentifier{ServerID: serverID}) + ack.UpdateOption(dhcpv4.OptServerIdentifier(serverID)) - _, err := InformSelectForAck(*ack, 11223, bootImage) + _, err := InformSelectForAck(PacketFor(ack), 11223, bootImage) require.Error(t, err, "expect error for > 1024 replyPort") } @@ -204,16 +202,15 @@ func TestInformSelectForAck_ReplyPort(t *testing.T) { } ack := newAck(hwAddr, tid) ack.SetBroadcast() - ack.UpdateOption(&dhcpv4.OptServerIdentifier{ServerID: serverID}) + ack.UpdateOption(dhcpv4.OptServerIdentifier(serverID)) replyPort := uint16(999) - m, err := InformSelectForAck(*ack, replyPort, bootImage) + m, err := InformSelectForAck(PacketFor(ack), replyPort, bootImage) require.NoError(t, err) require.True(t, m.Options.Has(dhcpv4.OptionVendorSpecificInformation)) - opt := m.GetOneOption(dhcpv4.OptionVendorSpecificInformation) - vendorInfo := opt.(*OptVendorSpecificInformation) - RequireHasOption(t, vendorInfo.Options, &OptReplyPort{Port: replyPort}) + vendorOpts := GetVendorOptions(m.Options).Options + RequireHasOption(t, vendorOpts, OptReplyPort(replyPort)) } func TestNewReplyForInformList_NoDefaultImage(t *testing.T) { @@ -274,24 +271,24 @@ func TestNewReplyForInformList(t *testing.T) { require.Equal(t, "bsdp.foo.com", ack.ServerHostName) // Validate options. - RequireHasOption(t, ack.Options, &dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) - RequireHasOption(t, ack.Options, &dhcpv4.OptServerIdentifier{ServerID: net.IP{9, 9, 9, 9}}) - RequireHasOption(t, ack.Options, &dhcpv4.OptClassIdentifier{Identifier: AppleVendorID}) + RequireHasOption(t, ack.Options, dhcpv4.OptMessageType(dhcpv4.MessageTypeAck)) + RequireHasOption(t, ack.Options, dhcpv4.OptServerIdentifier(net.IP{9, 9, 9, 9})) + RequireHasOption(t, ack.Options, dhcpv4.OptClassIdentifier(AppleVendorID)) require.NotNil(t, ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation)) // Vendor-specific options. - vendorOpts := ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation).(*OptVendorSpecificInformation) - RequireHasOption(t, vendorOpts.Options, &OptMessageType{Type: MessageTypeList}) - RequireHasOption(t, vendorOpts.Options, &OptDefaultBootImageID{ID: images[0].ID}) - RequireHasOption(t, vendorOpts.Options, &OptServerPriority{Priority: 0x7070}) - RequireHasOption(t, vendorOpts.Options, &OptBootImageList{Images: images}) + vendorOpts := GetVendorOptions(ack.Options).Options + RequireHasOption(t, vendorOpts, OptMessageType(MessageTypeList)) + RequireHasOption(t, vendorOpts, OptDefaultBootImageID(images[0].ID)) + RequireHasOption(t, vendorOpts, OptServerPriority(0x7070)) + RequireHasOption(t, vendorOpts, OptBootImageList(images...)) // Add in selected boot image, ensure it's in the generated ACK. config.SelectedImage = &images[0] ack, err = NewReplyForInformList(inform, config) require.NoError(t, err) - vendorOpts = ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation).(*OptVendorSpecificInformation) - RequireHasOption(t, vendorOpts.Options, &OptSelectedBootImageID{ID: images[0].ID}) + vendorOpts = GetVendorOptions(ack.Options).Options + RequireHasOption(t, vendorOpts, OptSelectedBootImageID(images[0].ID)) } func TestNewReplyForInformSelect_NoSelectedImage(t *testing.T) { @@ -352,30 +349,22 @@ func TestNewReplyForInformSelect(t *testing.T) { require.Equal(t, "bsdp.foo.com", ack.ServerHostName) // Validate options. - RequireHasOption(t, ack.Options, &dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) - RequireHasOption(t, ack.Options, &dhcpv4.OptServerIdentifier{ServerID: net.IP{9, 9, 9, 9}}) - RequireHasOption(t, ack.Options, &dhcpv4.OptServerIdentifier{ServerID: net.IP{9, 9, 9, 9}}) - RequireHasOption(t, ack.Options, &dhcpv4.OptClassIdentifier{Identifier: AppleVendorID}) + RequireHasOption(t, ack.Options, dhcpv4.OptMessageType(dhcpv4.MessageTypeAck)) + RequireHasOption(t, ack.Options, dhcpv4.OptServerIdentifier(net.IP{9, 9, 9, 9})) + RequireHasOption(t, ack.Options, dhcpv4.OptServerIdentifier(net.IP{9, 9, 9, 9})) + RequireHasOption(t, ack.Options, dhcpv4.OptClassIdentifier(AppleVendorID)) require.NotNil(t, ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation)) - vendorOpts := ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation).(*OptVendorSpecificInformation) - RequireHasOption(t, vendorOpts.Options, &OptMessageType{Type: MessageTypeSelect}) - RequireHasOption(t, vendorOpts.Options, &OptSelectedBootImageID{ID: images[0].ID}) + vendorOpts := GetVendorOptions(ack.Options) + RequireHasOption(t, vendorOpts.Options, OptMessageType(MessageTypeSelect)) + RequireHasOption(t, vendorOpts.Options, OptSelectedBootImageID(images[0].ID)) } func TestMessageTypeForPacket(t *testing.T) { - var ( - pkt *dhcpv4.DHCPv4 - gotMessageType *MessageType - ) - - list := new(MessageType) - *list = MessageTypeList - testcases := []struct { tcName string opts []dhcpv4.Option - wantMessageType *MessageType + wantMessageType MessageType }{ { tcName: "No options", @@ -384,45 +373,38 @@ func TestMessageTypeForPacket(t *testing.T) { { tcName: "Some options, no vendor opts", opts: []dhcpv4.Option{ - &dhcpv4.OptHostName{HostName: "foobar1234"}, + dhcpv4.OptHostName("foobar1234"), }, }, { tcName: "Vendor opts, no message type", opts: []dhcpv4.Option{ - &dhcpv4.OptHostName{HostName: "foobar1234"}, - &OptVendorSpecificInformation{ - Options: []dhcpv4.Option{ - Version1_1, - }, - }, + dhcpv4.OptHostName("foobar1234"), + OptVendorOptions( + OptVersion(Version1_1), + ), }, }, { tcName: "Vendor opts, with message type", opts: []dhcpv4.Option{ - &dhcpv4.OptHostName{HostName: "foobar1234"}, - &OptVendorSpecificInformation{ - Options: []dhcpv4.Option{ - Version1_1, - &OptMessageType{Type: MessageTypeList}, - }, - }, + dhcpv4.OptHostName("foobar1234"), + OptVendorOptions( + OptVersion(Version1_1), + OptMessageType(MessageTypeList), + ), }, - wantMessageType: list, + wantMessageType: MessageTypeList, }, } for _, tt := range testcases { t.Run(tt.tcName, func(t *testing.T) { - pkt, _ = dhcpv4.New() + pkt, _ := dhcpv4.New() for _, opt := range tt.opts { pkt.UpdateOption(opt) } - gotMessageType = MessageTypeFromPacket(pkt) + gotMessageType := MessageTypeFromPacket(pkt) require.Equal(t, tt.wantMessageType, gotMessageType) - if tt.wantMessageType != nil { - require.Equal(t, *tt.wantMessageType, *gotMessageType) - } }) } } diff --git a/dhcpv4/bsdp/client.go b/dhcpv4/bsdp/client.go index dd4a0a0..e8ca2ca 100644 --- a/dhcpv4/bsdp/client.go +++ b/dhcpv4/bsdp/client.go @@ -18,24 +18,10 @@ func NewClient() *Client { return &Client{Client: dhcpv4.Client{}} } -func castVendorOpt(ack *dhcpv4.DHCPv4) { - opts := ack.Options - for i := 0; i < len(opts); i++ { - if opts[i].Code() == dhcpv4.OptionVendorSpecificInformation { - vendorOpt, err := ParseOptVendorSpecificInformation(opts[i].ToBytes()) - // Oh well, we tried - if err != nil { - return - } - opts[i] = vendorOpt - } - } -} - // Exchange runs a full BSDP exchange (Inform[list], Ack, Inform[select], // Ack). Returns a list of DHCPv4 structures representing the exchange. -func (c *Client) Exchange(ifname string) ([]*dhcpv4.DHCPv4, error) { - conversation := make([]*dhcpv4.DHCPv4, 0) +func (c *Client) Exchange(ifname string) ([]*Packet, error) { + conversation := make([]*Packet, 0) // Get our file descriptor for the broadcast socket. sendFd, err := dhcpv4.MakeBroadcastSocket(ifname) @@ -55,17 +41,16 @@ func (c *Client) Exchange(ifname string) ([]*dhcpv4.DHCPv4, error) { conversation = append(conversation, informList) // ACK[LIST] - ackForList, err := c.Client.SendReceive(sendFd, recvFd, informList, dhcpv4.MessageTypeAck) + ackForList, err := c.Client.SendReceive(sendFd, recvFd, informList.v4(), dhcpv4.MessageTypeAck) if err != nil { return conversation, err } // Rewrite vendor-specific option for pretty printing. - castVendorOpt(ackForList) - conversation = append(conversation, ackForList) + conversation = append(conversation, PacketFor(ackForList)) // Parse boot images sent back by server - bootImages, err := ParseBootImageListFromAck(*ackForList) + bootImages, err := ParseBootImageListFromAck(ackForList) if err != nil { return conversation, err } @@ -74,17 +59,16 @@ func (c *Client) Exchange(ifname string) ([]*dhcpv4.DHCPv4, error) { } // INFORM[SELECT] - informSelect, err := InformSelectForAck(*ackForList, dhcpv4.ClientPort, bootImages[0]) + informSelect, err := InformSelectForAck(PacketFor(ackForList), dhcpv4.ClientPort, bootImages[0]) if err != nil { return conversation, err } conversation = append(conversation, informSelect) // ACK[SELECT] - ackForSelect, err := c.Client.SendReceive(sendFd, recvFd, informSelect, dhcpv4.MessageTypeAck) - castVendorOpt(ackForSelect) + ackForSelect, err := c.Client.SendReceive(sendFd, recvFd, informSelect.v4(), dhcpv4.MessageTypeAck) if err != nil { return conversation, err } - return append(conversation, ackForSelect), nil + return append(conversation, PacketFor(ackForSelect)), nil } diff --git a/dhcpv4/bsdp/option_vendor_specific_information.go b/dhcpv4/bsdp/option_vendor_specific_information.go index a87135f..4e107e1 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information.go +++ b/dhcpv4/bsdp/option_vendor_specific_information.go @@ -1,93 +1,87 @@ package bsdp import ( - "strings" + "fmt" "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/u-root/u-root/pkg/uio" ) -// OptVendorSpecificInformation encapsulates the BSDP-specific options used for -// the protocol. -type OptVendorSpecificInformation struct { - Options dhcpv4.Options +// VendorOptions is like dhcpv4.Options, but stringifies using BSDP-specific +// option codes. +type VendorOptions struct { + dhcpv4.Options } -// parseOption is similar to dhcpv4.ParseOption, except that it switches based -// on the BSDP specific options. -func parseOption(code dhcpv4.OptionCode, data []byte) (dhcpv4.Option, error) { - var ( - opt dhcpv4.Option - err error - ) - switch code { - case OptionBootImageList: - opt, err = ParseOptBootImageList(data) - case OptionDefaultBootImageID: - opt, err = ParseOptDefaultBootImageID(data) - case OptionMachineName: - opt, err = ParseOptMachineName(data) - case OptionMessageType: - opt, err = ParseOptMessageType(data) - case OptionReplyPort: - opt, err = ParseOptReplyPort(data) - case OptionSelectedBootImageID: - opt, err = ParseOptSelectedBootImageID(data) - case OptionServerIdentifier: - opt, err = ParseOptServerIdentifier(data) - case OptionServerPriority: - opt, err = ParseOptServerPriority(data) - case OptionVersion: - opt, err = ParseOptVersion(data) - default: - opt, err = ParseOptGeneric(code, data) - } - if err != nil { - return nil, err - } - return opt, nil +// String prints the contained options using BSDP-specific option code parsing. +func (v VendorOptions) String() string { + return v.Options.ToString(bsdpHumanizer) } -// codeGetter is a dhcpv4.OptionCodeGetter for BSDP optionCodes. -func codeGetter(c uint8) dhcpv4.OptionCode { - return optionCode(c) +// FromBytes parses vendor options from +func (v *VendorOptions) FromBytes(data []byte) error { + v.Options = make(dhcpv4.Options) + return v.Options.FromBytes(data) } -// ParseOptVendorSpecificInformation constructs an OptVendorSpecificInformation struct from a sequence of -// bytes and returns it, or an error. -func ParseOptVendorSpecificInformation(data []byte) (*OptVendorSpecificInformation, error) { - options, err := dhcpv4.OptionsFromBytesWithParser(data, codeGetter, parseOption, false /* don't check for OptionEnd tag */) - if err != nil { - return nil, err +// OptVendorOptions returns the BSDP Vendor Specific Info in o. +func OptVendorOptions(o ...dhcpv4.Option) dhcpv4.Option { + return dhcpv4.Option{ + Code: dhcpv4.OptionVendorSpecificInformation, + Value: VendorOptions{dhcpv4.OptionsFromList(o...)}, } - return &OptVendorSpecificInformation{options}, nil } -// Code returns the option code. -func (o *OptVendorSpecificInformation) Code() dhcpv4.OptionCode { - return dhcpv4.OptionVendorSpecificInformation +// GetVendorOptions returns a new BSDP Vendor Specific Info option. +func GetVendorOptions(o dhcpv4.Options) *VendorOptions { + v := o.Get(dhcpv4.OptionVendorSpecificInformation) + if v == nil { + return nil + } + var vo VendorOptions + if err := vo.FromBytes(v); err != nil { + return nil + } + return &vo } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptVendorSpecificInformation) ToBytes() []byte { - return uio.ToBigEndian(o.Options) +var bsdpHumanizer = dhcpv4.OptionHumanizer{ + ValueHumanizer: parseOption, + CodeHumanizer: func(c uint8) dhcpv4.OptionCode { + return optionCode(c) + }, } -// String returns a human-readable string for this option. -func (o *OptVendorSpecificInformation) String() string { - s := "Vendor Specific Information ->" - for _, opt := range o.Options { - optString := opt.String() - // If this option has sub-structures, offset them accordingly. - if strings.Contains(optString, "\n") { - optString = strings.Replace(optString, "\n ", "\n ", -1) - } - s += "\n " + optString - } - return s -} +// parseOption is similar to dhcpv4.parseOption, except that it interprets +// option codes based on the BSDP-specific options. +func parseOption(code dhcpv4.OptionCode, data []byte) fmt.Stringer { + var d dhcpv4.OptionDecoder + switch code { + case OptionMachineName: + var s dhcpv4.String + d = &s + + case OptionServerIdentifier: + d = &dhcpv4.IP{} + + case OptionServerPriority, OptionReplyPort: + var u dhcpv4.Uint16 + d = &u + + case OptionBootImageList: + d = &BootImageList{} -// GetOneOption returns the first suboption that matches the OptionCode code. -func (o *OptVendorSpecificInformation) GetOneOption(code dhcpv4.OptionCode) dhcpv4.Option { - return o.Options.GetOne(code) + case OptionDefaultBootImageID, OptionSelectedBootImageID: + d = &BootImageID{} + + case OptionMessageType: + var m MessageType + d = &m + + case OptionVersion: + d = &Version{} + } + if d != nil && d.FromBytes(data) == nil { + return d + } + return dhcpv4.OptionGeneric{data} } diff --git a/dhcpv4/bsdp/option_vendor_specific_information_test.go b/dhcpv4/bsdp/option_vendor_specific_information_test.go index ede8a0b..a6727f5 100644 --- a/dhcpv4/bsdp/option_vendor_specific_information_test.go +++ b/dhcpv4/bsdp/option_vendor_specific_information_test.go @@ -1,6 +1,7 @@ package bsdp import ( + "net" "testing" "github.com/insomniacslk/dhcp/dhcpv4" @@ -8,182 +9,71 @@ import ( ) func TestOptVendorSpecificInformationInterfaceMethods(t *testing.T) { - messageTypeOpt := &OptMessageType{MessageTypeList} - versionOpt := Version1_1 - o := &OptVendorSpecificInformation{[]dhcpv4.Option{messageTypeOpt, versionOpt}} - require.Equal(t, dhcpv4.OptionVendorSpecificInformation, o.Code(), "Code") - - expectedBytes := []byte{ - 1, 1, 1, // List option - 2, 2, 1, 1, // Version option - } - o = &OptVendorSpecificInformation{ - []dhcpv4.Option{ - &OptMessageType{MessageTypeList}, - Version1_1, - }, - } - require.Equal(t, expectedBytes, o.ToBytes(), "ToBytes") -} - -func TestParseOptVendorSpecificInformation(t *testing.T) { - var ( - o *OptVendorSpecificInformation - err error + o := OptVendorOptions( + OptVersion(Version1_1), + OptMessageType(MessageTypeList), ) - o, err = ParseOptVendorSpecificInformation([]byte{1, 2}) - require.Error(t, err, "short byte stream") - - // Good byte stream - data := []byte{ - 1, 1, 1, // List option - 2, 2, 1, 1, // Version option - } - o, err = ParseOptVendorSpecificInformation(data) - require.NoError(t, err) - expected := &OptVendorSpecificInformation{ - []dhcpv4.Option{ - &OptMessageType{MessageTypeList}, - Version1_1, - }, - } - require.Equal(t, 2, len(o.Options), "number of parsed suboptions") - typ := o.GetOneOption(OptionMessageType) - version := o.GetOneOption(OptionVersion) - require.Equal(t, expected.Options[0].Code(), typ.Code()) - require.Equal(t, expected.Options[1].Code(), version.Code()) - - // Short byte stream (length and data mismatch) - data = []byte{ - 1, 1, 1, // List option - 2, 2, 1, // Version option - } - o, err = ParseOptVendorSpecificInformation(data) - require.Error(t, err) - - // Bad option - data = []byte{ - 1, 1, 1, // List option - 2, 2, 1, // Version option - 5, 3, 1, 1, 1, // Reply port option - } - o, err = ParseOptVendorSpecificInformation(data) - require.Error(t, err) + require.Equal(t, dhcpv4.OptionVendorSpecificInformation, o.Code, "Code") - // Boot images + default. - data = []byte{ + expectedBytes := []byte{ 1, 1, 1, // List option 2, 2, 1, 1, // Version option - 5, 2, 1, 1, // Reply port option - - // Boot image list - 9, 22, - 0x1, 0x0, 0x03, 0xe9, // ID - 6, // name length - 'b', 's', 'd', 'p', '-', '1', - 0x80, 0x0, 0x23, 0x31, // ID - 6, // name length - 'b', 's', 'd', 'p', '-', '2', - - // Default Boot Image ID - 7, 4, 0x1, 0x0, 0x03, 0xe9, } - o, err = ParseOptVendorSpecificInformation(data) - require.NoError(t, err) - require.Equal(t, 5, len(o.Options)) - for _, opt := range []dhcpv4.OptionCode{ - OptionMessageType, - OptionVersion, - OptionReplyPort, - OptionBootImageList, - OptionDefaultBootImageID, - } { - require.True(t, o.Options.Has(opt)) - } - optBootImage := o.GetOneOption(OptionBootImageList).(*OptBootImageList) - expectedBootImages := []BootImage{ - BootImage{ - ID: BootImageID{ - IsInstall: false, - ImageType: BootImageTypeMacOSX, - Index: 1001, - }, - Name: "bsdp-1", - }, - BootImage{ - ID: BootImageID{ - IsInstall: true, - ImageType: BootImageTypeMacOS9, - Index: 9009, - }, - Name: "bsdp-2", - }, - } - require.Equal(t, expectedBootImages, optBootImage.Images) + require.Equal(t, expectedBytes, o.Value.ToBytes(), "ToBytes") } func TestOptVendorSpecificInformationString(t *testing.T) { - o := &OptVendorSpecificInformation{ - []dhcpv4.Option{ - &OptMessageType{MessageTypeList}, - Version1_1, - }, - } - expectedString := "Vendor Specific Information ->\n BSDP Message Type -> LIST\n BSDP Version -> 1.1" + o := OptVendorOptions( + OptMessageType(MessageTypeList), + OptVersion(Version1_1), + ) + expectedString := "Vendor Specific Information:\n BSDP Message Type: LIST\n BSDP Version: 1.1\n" require.Equal(t, expectedString, o.String()) // Test more complicated string - sub options of sub options. - o = &OptVendorSpecificInformation{ - []dhcpv4.Option{ - &OptMessageType{MessageTypeList}, - &OptBootImageList{ - []BootImage{ - BootImage{ - ID: BootImageID{ - IsInstall: false, - ImageType: BootImageTypeMacOSX, - Index: 1001, - }, - Name: "bsdp-1", - }, - BootImage{ - ID: BootImageID{ - IsInstall: true, - ImageType: BootImageTypeMacOS9, - Index: 9009, - }, - Name: "bsdp-2", - }, + o = OptVendorOptions( + OptMessageType(MessageTypeList), + OptBootImageList( + BootImage{ + ID: BootImageID{ + IsInstall: false, + ImageType: BootImageTypeMacOSX, + Index: 1001, }, + Name: "bsdp-1", }, - }, - } - expectedString = "Vendor Specific Information ->\n" + - " BSDP Message Type -> LIST\n" + - " BSDP Boot Image List ->\n" + - " bsdp-1 [1001] uninstallable macOS image\n" + - " bsdp-2 [9009] installable macOS 9 image" + BootImage{ + ID: BootImageID{ + IsInstall: true, + ImageType: BootImageTypeMacOS9, + Index: 9009, + }, + Name: "bsdp-2", + }, + ), + OptMachineName("foo"), + OptServerIdentifier(net.IP{1, 1, 1, 1}), + OptServerPriority(1234), + OptReplyPort(1235), + OptDefaultBootImageID(BootImageID{ + IsInstall: true, + ImageType: BootImageTypeMacOS9, + Index: 9009, + }), + OptSelectedBootImageID(BootImageID{ + IsInstall: true, + ImageType: BootImageTypeMacOS9, + Index: 9009, + }), + ) + expectedString = "Vendor Specific Information:\n" + + " BSDP Message Type: LIST\n" + + " BSDP Server Identifier: 1.1.1.1\n" + + " BSDP Server Priority: 1234\n" + + " BSDP Reply Port: 1235\n" + + " BSDP Default Boot Image ID: [9009] installable macOS 9 image\n" + + " BSDP Selected Boot Image ID: [9009] installable macOS 9 image\n" + + " BSDP Boot Image List: bsdp-1 [1001] uninstallable macOS image, bsdp-2 [9009] installable macOS 9 image\n" + + " BSDP Machine Name: foo\n" require.Equal(t, expectedString, o.String()) } - -func TestOptVendorSpecificInformationGetOneOption(t *testing.T) { - // No option - o := &OptVendorSpecificInformation{ - []dhcpv4.Option{ - &OptMessageType{MessageTypeList}, - Version1_1, - }, - } - foundOpt := o.GetOneOption(OptionBootImageList) - require.Nil(t, foundOpt, "should not get options") - - // One option - o = &OptVendorSpecificInformation{ - []dhcpv4.Option{ - &OptMessageType{MessageTypeList}, - Version1_1, - }, - } - foundOpt = o.GetOneOption(OptionMessageType) - require.Equal(t, MessageTypeList, foundOpt.(*OptMessageType).Type) -} diff --git a/dhcpv4/bsdp/types.go b/dhcpv4/bsdp/types.go index 4ce840f..4931081 100644 --- a/dhcpv4/bsdp/types.go +++ b/dhcpv4/bsdp/types.go @@ -1,5 +1,9 @@ package bsdp +import ( + "fmt" +) + // DefaultMacOSVendorClassIdentifier is a default vendor class identifier used // on non-darwin hosts where the vendor class identifier cannot be determined. // It should mostly be used for debugging if testing BSDP on a non-darwin @@ -19,7 +23,7 @@ func (o optionCode) String() string { if s, ok := optionCodeToString[o]; ok { return s } - return "unknown" + return fmt.Sprintf("unknown (%d)", o) } // Options (occur as sub-options of DHCP option 43). diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index bbb2f37..d94c58a 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -1,3 +1,18 @@ +// Package dhcpv4 provides encoding and decoding of DHCPv4 packets and options. +// +// Example Usage: +// +// p, err := dhcpv4.New( +// dhcpv4.WithClientIP(net.IP{192, 168, 0, 1}), +// dhcpv4.WithMessageType(dhcpv4.MessageTypeInform), +// ) +// p.UpdateOption(dhcpv4.OptServerIdentifier(net.IP{192, 110, 110, 110})) +// +// // Retrieve the DHCP Message Type option. +// m := dhcpv4.GetMessageType(p.Options) +// +// bytesOnTheWire := p.ToBytes() +// longSummary := p.Summary() package dhcpv4 import ( @@ -121,7 +136,7 @@ func New(modifiers ...Modifier) (*DHCPv4, error) { YourIPAddr: net.IPv4zero, ServerIPAddr: net.IPv4zero, GatewayIPAddr: net.IPv4zero, - Options: make([]Option, 0, 10), + Options: make(Options), } for _, mod := range modifiers { mod(&d) @@ -203,11 +218,7 @@ func NewInform(hwaddr net.HardwareAddr, localIP net.IP, modifiers ...Modifier) ( // NewRequestFromOffer builds a DHCPv4 request from an offer. func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) { // find server IP address - var serverIP net.IP - serverID := offer.GetOneOption(OptionServerIdentifier) - if serverID != nil { - serverIP = serverID.(*OptServerIdentifier).ServerID - } + serverIP := GetServerIdentifier(offer.Options) if serverIP == nil { return nil, errors.New("Missing Server IP Address in DHCP Offer") } @@ -216,8 +227,8 @@ func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) WithReply(offer), WithMessageType(MessageTypeRequest), WithServerIP(serverIP), - WithOption(&OptRequestedIPAddress{RequestedAddr: offer.YourIPAddr}), - WithOption(&OptServerIdentifier{ServerID: serverIP}), + WithOption(OptRequestedIPAddress(offer.YourIPAddr)), + WithOption(OptServerIdentifier(serverIP)), )...) } @@ -281,11 +292,10 @@ func FromBytes(q []byte) (*DHCPv4, error) { return nil, fmt.Errorf("malformed DHCP packet: got magic cookie %v, want %v", cookie[:], magicCookie[:]) } - opts, err := OptionsFromBytes(buf.Data()) - if err != nil { + p.Options = make(Options) + if err := p.Options.fromBytesCheckEnd(buf.Data(), true); err != nil { return nil, err } - p.Options = opts return &p, nil } @@ -325,25 +335,25 @@ func (d *DHCPv4) SetUnicast() { // GetOneOption returns the option that matches the given option code. // -// If no matching option is found, nil is returned. -func (d *DHCPv4) GetOneOption(code OptionCode) Option { - return d.Options.GetOne(code) +// According to RFC 3396, options that are specified more than once are +// concatenated, and hence this should always just return one option. +func (d *DHCPv4) GetOneOption(code OptionCode) []byte { + return d.Options.Get(code) } // UpdateOption replaces an existing option with the same option code with the // given one, adding it if not already present. -func (d *DHCPv4) UpdateOption(option Option) { - d.Options.Update(option) +func (d *DHCPv4) UpdateOption(opt Option) { + if d.Options == nil { + d.Options = make(Options) + } + d.Options.Update(opt) } // MessageType returns the message type, trying to extract it from the // OptMessageType option. It returns nil if the message type cannot be extracted func (d *DHCPv4) MessageType() MessageType { - opt := d.GetOneOption(OptionDHCPMessageType) - if opt == nil { - return MessageTypeNone - } - return opt.(*OptMessageType).MessageType + return GetMessageType(d.Options) } // String implements fmt.Stringer. @@ -352,23 +362,24 @@ func (d *DHCPv4) String() string { d.OpCode, d.TransactionID, d.HWType, d.ClientHWAddr) } -// Summary prints detailed information about the packet. -func (d *DHCPv4) Summary() string { +// SummaryWithVendor prints a summary of the packet, interpreting the +// vendor-specific info option using the given parser (can be nil). +func (d *DHCPv4) SummaryWithVendor(vendorDecoder OptionDecoder) string { ret := fmt.Sprintf( - "DHCPv4\n"+ - " opcode=%s\n"+ - " hwtype=%s\n"+ - " hopcount=%v\n"+ - " transactionid=%s\n"+ - " numseconds=%v\n"+ - " flags=%v (0x%02x)\n"+ - " clientipaddr=%s\n"+ - " youripaddr=%s\n"+ - " serveripaddr=%s\n"+ - " gatewayipaddr=%s\n"+ - " clienthwaddr=%s\n"+ - " serverhostname=%s\n"+ - " bootfilename=%s\n", + "DHCPv4 Message\n"+ + " opcode: %s\n"+ + " hwtype: %s\n"+ + " hopcount: %v\n"+ + " transaction ID: %s\n"+ + " num seconds: %v\n"+ + " flags: %v (0x%02x)\n"+ + " client IP: %s\n"+ + " your IP: %s\n"+ + " server IP: %s\n"+ + " gateway IP: %s\n"+ + " client MAC: %s\n"+ + " server hostname: %s\n"+ + " bootfile name: %s\n", d.OpCode, d.HWType, d.HopCount, @@ -384,29 +395,20 @@ func (d *DHCPv4) Summary() string { d.ServerHostName, d.BootFileName, ) - ret += " options=\n" - for _, opt := range d.Options { - optString := opt.String() - // If this option has sub structures, offset them accordingly. - if strings.Contains(optString, "\n") { - optString = strings.Replace(optString, "\n ", "\n ", -1) - } - ret += fmt.Sprintf(" %v\n", optString) - if opt.Code() == OptionEnd { - break - } - } + ret += " options:\n" + ret += d.Options.Summary(vendorDecoder) return ret } +// Summary prints detailed information about the packet. +func (d *DHCPv4) Summary() string { + return d.SummaryWithVendor(nil) +} + // IsOptionRequested returns true if that option is within the requested // options of the DHCPv4 message. func (d *DHCPv4) IsOptionRequested(requested OptionCode) bool { - optprl := d.GetOneOption(OptionParameterRequestList) - if optprl == nil { - return false - } - for _, o := range optprl.(*OptParameterRequestList).RequestedOpts { + for _, o := range GetParameterRequestList(d.Options) { if o == requested { return true } @@ -459,7 +461,12 @@ func (d *DHCPv4) ToBytes() []byte { // The magic cookie. buf.WriteBytes(magicCookie[:]) + + // Write all options. d.Options.Marshal(buf) + + // Finish the packet. buf.Write8(uint8(OptionEnd)) + return buf.Data() } diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index fb0ef70..a3e2a8a 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -179,43 +179,43 @@ func TestGetOption(t *testing.T) { t.Fatal(err) } - hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")} - bootFileOpt := &OptBootfileName{"boot.img"} + hostnameOpt := OptGeneric(OptionHostName, []byte("darkstar")) + bootFileOpt := OptBootFileName("boot.img") d.UpdateOption(hostnameOpt) d.UpdateOption(bootFileOpt) - require.Equal(t, d.GetOneOption(OptionHostName), hostnameOpt) - require.Equal(t, d.GetOneOption(OptionBootfileName), bootFileOpt) - require.Equal(t, d.GetOneOption(OptionRouter), nil) + require.Equal(t, d.GetOneOption(OptionHostName), []byte("darkstar")) + require.Equal(t, d.GetOneOption(OptionBootfileName), []byte("boot.img")) + require.Equal(t, d.GetOneOption(OptionRouter), []byte(nil)) } func TestUpdateOption(t *testing.T) { d, err := New() require.NoError(t, err) - hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")} - bootFileOpt1 := &OptBootfileName{"boot.img"} - bootFileOpt2 := &OptBootfileName{"boot2.img"} + hostnameOpt := OptGeneric(OptionHostName, []byte("darkstar")) + bootFileOpt1 := OptBootFileName("boot.img") + bootFileOpt2 := OptBootFileName("boot2.img") d.UpdateOption(hostnameOpt) d.UpdateOption(bootFileOpt1) d.UpdateOption(bootFileOpt2) options := d.Options require.Equal(t, len(options), 2) - require.Equal(t, d.Options.GetOne(OptionBootfileName), bootFileOpt2) - require.Equal(t, d.Options.GetOne(OptionHostName), hostnameOpt) + require.Equal(t, d.GetOneOption(OptionHostName), []byte("darkstar")) + require.Equal(t, d.GetOneOption(OptionBootfileName), []byte("boot2.img")) } func TestDHCPv4NewRequestFromOffer(t *testing.T) { offer, err := New() require.NoError(t, err) offer.SetBroadcast() - offer.UpdateOption(&OptMessageType{MessageType: MessageTypeOffer}) + offer.UpdateOption(OptMessageType(MessageTypeOffer)) req, err := NewRequestFromOffer(offer) require.Error(t, err) // Now add the option so it doesn't error out. - offer.UpdateOption(&OptServerIdentifier{ServerID: net.IPv4(192, 168, 0, 1)}) + offer.UpdateOption(OptServerIdentifier(net.IPv4(192, 168, 0, 1))) // Broadcast request req, err = NewRequestFromOffer(offer) @@ -235,13 +235,12 @@ func TestDHCPv4NewRequestFromOffer(t *testing.T) { func TestDHCPv4NewRequestFromOfferWithModifier(t *testing.T) { offer, err := New() require.NoError(t, err) - offer.UpdateOption(&OptMessageType{MessageType: MessageTypeOffer}) - offer.UpdateOption(&OptServerIdentifier{ServerID: net.IPv4(192, 168, 0, 1)}) + offer.UpdateOption(OptMessageType(MessageTypeOffer)) + offer.UpdateOption(OptServerIdentifier(net.IPv4(192, 168, 0, 1))) userClass := WithUserClass([]byte("linuxboot"), false) req, err := NewRequestFromOffer(offer, userClass) require.NoError(t, err) require.Equal(t, MessageTypeRequest, req.MessageType()) - require.Equal(t, "User Class Information -> linuxboot", req.Options[3].String()) } func TestNewReplyFromRequest(t *testing.T) { @@ -263,7 +262,6 @@ func TestNewReplyFromRequestWithModifier(t *testing.T) { require.NoError(t, err) require.Equal(t, discover.TransactionID, reply.TransactionID) require.Equal(t, discover.GatewayIPAddr, reply.GatewayIPAddr) - require.Equal(t, "User Class Information -> linuxboot", reply.Options[0].String()) } func TestDHCPv4MessageTypeNil(t *testing.T) { @@ -304,10 +302,33 @@ func TestIsOptionRequested(t *testing.T) { require.NoError(t, err) require.False(t, pkt.IsOptionRequested(OptionDomainNameServer)) - optprl := OptParameterRequestList{RequestedOpts: []OptionCode{OptionDomainNameServer}} - pkt.UpdateOption(&optprl) + optprl := OptParameterRequestList(OptionDomainNameServer) + pkt.UpdateOption(optprl) require.True(t, pkt.IsOptionRequested(OptionDomainNameServer)) } // TODO // test Summary() and String() +func TestSummary(t *testing.T) { + packet, err := New(WithMessageType(MessageTypeInform)) + packet.TransactionID = [4]byte{1, 1, 1, 1} + require.NoError(t, err) + + want := "DHCPv4 Message\n" + + " opcode: BootRequest\n" + + " hwtype: Ethernet\n" + + " hopcount: 0\n" + + " transaction ID: 0x01010101\n" + + " num seconds: 0\n" + + " flags: Unicast (0x00)\n" + + " client IP: 0.0.0.0\n" + + " your IP: 0.0.0.0\n" + + " server IP: 0.0.0.0\n" + + " gateway IP: 0.0.0.0\n" + + " client MAC: \n" + + " server hostname: \n" + + " bootfile name: \n" + + " options:\n" + + " DHCP Message Type: INFORM\n" + require.Equal(t, want, packet.Summary()) +} diff --git a/dhcpv4/modifiers.go b/dhcpv4/modifiers.go index 0759491..431fdfd 100644 --- a/dhcpv4/modifiers.go +++ b/dhcpv4/modifiers.go @@ -2,6 +2,7 @@ package dhcpv4 import ( "net" + "time" "github.com/insomniacslk/dhcp/iana" "github.com/insomniacslk/dhcp/rfc1035label" @@ -89,10 +90,13 @@ func WithOption(opt Option) Modifier { // rfc compliant or not. More details in issue #113 func WithUserClass(uc []byte, rfc bool) Modifier { // TODO let the user specify multiple user classes - return WithOption(&OptUserClass{ - UserClasses: [][]byte{uc}, - Rfc3004: rfc, - }) + return func(d *DHCPv4) { + if rfc { + d.UpdateOption(OptRFC3004UserClass([][]byte{uc})) + } else { + d.UpdateOption(OptUserClass(uc)) + } + } } // WithNetboot adds bootfile URL and bootfile param options to a DHCPv4 packet. @@ -102,7 +106,7 @@ func WithNetboot(d *DHCPv4) { // WithMessageType adds the DHCPv4 message type m to a packet. func WithMessageType(m MessageType) Modifier { - return WithOption(&OptMessageType{m}) + return WithOption(OptMessageType(m)) } // WithRequestedOptions adds requested options to the packet. @@ -110,10 +114,11 @@ func WithRequestedOptions(optionCodes ...OptionCode) Modifier { return func(d *DHCPv4) { params := d.GetOneOption(OptionParameterRequestList) if params == nil { - d.UpdateOption(&OptParameterRequestList{OptionCodeList(optionCodes)}) + d.UpdateOption(OptParameterRequestList(optionCodes...)) } else { - opts := params.(*OptParameterRequestList) - opts.RequestedOpts.Add(optionCodes...) + cl := OptionCodeList(GetParameterRequestList(d.Options)) + cl.Add(optionCodes...) + d.UpdateOption(OptParameterRequestList(cl...)) } } } @@ -124,33 +129,23 @@ func WithRelay(ip net.IP) Modifier { return func(d *DHCPv4) { d.SetUnicast() d.GatewayIPAddr = ip - d.HopCount += 1 + d.HopCount++ } } // WithNetmask adds or updates an OptSubnetMask func WithNetmask(mask net.IPMask) Modifier { - return WithOption(&OptSubnetMask{SubnetMask: mask}) + return WithOption(OptSubnetMask(mask)) } // WithLeaseTime adds or updates an OptIPAddressLeaseTime func WithLeaseTime(leaseTime uint32) Modifier { - return WithOption(&OptIPAddressLeaseTime{LeaseTime: leaseTime}) -} - -// WithDNS adds or updates an OptionDomainNameServer -func WithDNS(dnses ...net.IP) Modifier { - return WithOption(&OptDomainNameServer{NameServers: dnses}) + return WithOption(OptIPAddressLeaseTime(time.Duration(leaseTime) * time.Second)) } // WithDomainSearchList adds or updates an OptionDomainSearch func WithDomainSearchList(searchList ...string) Modifier { - return WithOption(&OptDomainSearch{DomainSearch: &rfc1035label.Labels{ + return WithOption(OptDomainSearch(&rfc1035label.Labels{ Labels: searchList, - }}) -} - -// WithRouter adds or updates an OptionRouter -func WithRouter(routers ...net.IP) Modifier { - return WithOption(&OptRouter{Routers: routers}) + })) } diff --git a/dhcpv4/modifiers_test.go b/dhcpv4/modifiers_test.go index 2cac2a0..6233a7d 100644 --- a/dhcpv4/modifiers_test.go +++ b/dhcpv4/modifiers_test.go @@ -3,6 +3,7 @@ package dhcpv4 import ( "net" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -35,13 +36,12 @@ func TestHwAddrModifier(t *testing.T) { } func TestWithOptionModifier(t *testing.T) { - d, err := New(WithOption(&OptDomainName{DomainName: "slackware.it"})) + d, err := New(WithOption(OptDomainName("slackware.it"))) require.NoError(t, err) - opt := d.GetOneOption(OptionDomainName) - require.NotNil(t, opt) - dnOpt := opt.(*OptDomainName) - require.Equal(t, "slackware.it", dnOpt.DomainName) + dnOpt := GetDomainName(d.Options) + require.NotNil(t, dnOpt) + require.Equal(t, "slackware.it", dnOpt) } func TestUserClassModifier(t *testing.T) { @@ -51,8 +51,7 @@ func TestUserClassModifier(t *testing.T) { expected := []byte{ 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } - require.Equal(t, "User Class Information -> linuxboot", d.Options[0].String()) - require.Equal(t, expected, d.Options[0].ToBytes()) + require.Equal(t, expected, d.GetOneOption(OptionUserClassInformation)) } func TestUserClassModifierRFC(t *testing.T) { @@ -62,43 +61,35 @@ func TestUserClassModifierRFC(t *testing.T) { expected := []byte{ 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } - require.Equal(t, "User Class Information -> linuxboot", d.Options[0].String()) - require.Equal(t, expected, d.Options[0].ToBytes()) + require.Equal(t, expected, d.GetOneOption(OptionUserClassInformation)) } func TestWithNetboot(t *testing.T) { d, err := New(WithNetboot) require.NoError(t, err) - require.Equal(t, "Parameter Request List -> TFTP Server Name, Bootfile Name", d.Options[0].String()) + require.Equal(t, "TFTP Server Name, Bootfile Name", GetParameterRequestList(d.Options).String()) } func TestWithNetbootExistingTFTP(t *testing.T) { - d, err := New() - require.NoError(t, err) - d.UpdateOption(&OptParameterRequestList{ - RequestedOpts: []OptionCode{OptionTFTPServerName}, - }) + d, _ := New() + d.UpdateOption(OptParameterRequestList(OptionTFTPServerName)) WithNetboot(d) - require.Equal(t, "Parameter Request List -> TFTP Server Name, Bootfile Name", d.Options[0].String()) + require.Equal(t, "TFTP Server Name, Bootfile Name", GetParameterRequestList(d.Options).String()) } func TestWithNetbootExistingBootfileName(t *testing.T) { d, _ := New() - d.UpdateOption(&OptParameterRequestList{ - RequestedOpts: []OptionCode{OptionBootfileName}, - }) + d.UpdateOption(OptParameterRequestList(OptionBootfileName)) WithNetboot(d) - require.Equal(t, "Parameter Request List -> Bootfile Name, TFTP Server Name", d.Options[0].String()) + require.Equal(t, "TFTP Server Name, Bootfile Name", GetParameterRequestList(d.Options).String()) } func TestWithNetbootExistingBoth(t *testing.T) { d, _ := New() - d.UpdateOption(&OptParameterRequestList{ - RequestedOpts: []OptionCode{OptionBootfileName, OptionTFTPServerName}, - }) + d.UpdateOption(OptParameterRequestList(OptionBootfileName, OptionTFTPServerName)) WithNetboot(d) - require.Equal(t, "Parameter Request List -> Bootfile Name, TFTP Server Name", d.Options[0].String()) + require.Equal(t, "TFTP Server Name, Bootfile Name", GetParameterRequestList(d.Options).String()) } func TestWithRequestedOptions(t *testing.T) { @@ -106,18 +97,16 @@ func TestWithRequestedOptions(t *testing.T) { d, err := New(WithRequestedOptions(OptionFQDN)) require.NoError(t, err) require.NotNil(t, d) - o := d.GetOneOption(OptionParameterRequestList) - require.NotNil(t, o) - opts := o.(*OptParameterRequestList) - require.ElementsMatch(t, opts.RequestedOpts, []OptionCode{OptionFQDN}) + opts := GetParameterRequestList(d.Options) + require.NotNil(t, opts) + require.ElementsMatch(t, opts, []OptionCode{OptionFQDN}) // Check if already set options are preserved WithRequestedOptions(OptionHostName)(d) require.NotNil(t, d) - o = d.GetOneOption(OptionParameterRequestList) - require.NotNil(t, o) - opts = o.(*OptParameterRequestList) - require.ElementsMatch(t, opts.RequestedOpts, []OptionCode{OptionFQDN, OptionHostName}) + opts = GetParameterRequestList(d.Options) + require.NotNil(t, opts) + require.ElementsMatch(t, opts, []OptionCode{OptionFQDN, OptionHostName}) } func TestWithRelay(t *testing.T) { @@ -134,56 +123,44 @@ func TestWithNetmask(t *testing.T) { d, err := New(WithNetmask(net.IPv4Mask(255, 255, 255, 0))) require.NoError(t, err) - require.Equal(t, 1, len(d.Options)) - require.Equal(t, OptionSubnetMask, d.Options[0].Code()) - osm := d.Options[0].(*OptSubnetMask) - require.Equal(t, net.IPv4Mask(255, 255, 255, 0), osm.SubnetMask) + osm := GetSubnetMask(d.Options) + require.Equal(t, net.IPv4Mask(255, 255, 255, 0), osm) } func TestWithLeaseTime(t *testing.T) { d, err := New(WithLeaseTime(uint32(3600))) require.NoError(t, err) - require.Equal(t, 1, len(d.Options)) - require.Equal(t, OptionIPAddressLeaseTime, d.Options[0].Code()) - olt := d.Options[0].(*OptIPAddressLeaseTime) - require.Equal(t, uint32(3600), olt.LeaseTime) + require.True(t, d.Options.Has(OptionIPAddressLeaseTime)) + olt := GetIPAddressLeaseTime(d.Options, 10*time.Second) + require.Equal(t, 3600*time.Second, olt) } func TestWithDNS(t *testing.T) { d, err := New(WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"))) require.NoError(t, err) - require.Equal(t, 1, len(d.Options)) - require.Equal(t, OptionDomainNameServer, d.Options[0].Code()) - olt := d.Options[0].(*OptDomainNameServer) - require.Equal(t, 2, len(olt.NameServers)) - require.Equal(t, net.ParseIP("10.0.0.1"), olt.NameServers[0]) - require.Equal(t, net.ParseIP("10.0.0.2"), olt.NameServers[1]) - require.NotEqual(t, net.ParseIP("10.0.0.1"), olt.NameServers[1]) + dns := GetDNS(d.Options) + require.Equal(t, net.ParseIP("10.0.0.1").To4(), dns[0]) + require.Equal(t, net.ParseIP("10.0.0.2").To4(), dns[1]) } func TestWithDomainSearchList(t *testing.T) { d, err := New(WithDomainSearchList("slackware.it", "dhcp.slackware.it")) require.NoError(t, err) - require.Equal(t, 1, len(d.Options)) - osl := d.Options[0].(*OptDomainSearch) - require.Equal(t, OptionDNSDomainSearchList, osl.Code()) - require.NotNil(t, osl.DomainSearch) - require.Equal(t, 2, len(osl.DomainSearch.Labels)) - require.Equal(t, "slackware.it", osl.DomainSearch.Labels[0]) - require.Equal(t, "dhcp.slackware.it", osl.DomainSearch.Labels[1]) + osl := GetDomainSearch(d.Options) + require.NotNil(t, osl) + require.Equal(t, 2, len(osl.Labels)) + require.Equal(t, "slackware.it", osl.Labels[0]) + require.Equal(t, "dhcp.slackware.it", osl.Labels[1]) } func TestWithRouter(t *testing.T) { - rtr := net.ParseIP("10.0.0.254") + rtr := net.ParseIP("10.0.0.254").To4() d, err := New(WithRouter(rtr)) require.NoError(t, err) - require.Equal(t, 1, len(d.Options)) - ortr := d.Options[0].(*OptRouter) - require.Equal(t, OptionRouter, ortr.Code()) - require.Equal(t, 1, len(ortr.Routers)) - require.Equal(t, rtr, ortr.Routers[0]) + ortr := GetRouter(d.Options) + require.Equal(t, rtr, ortr[0]) } diff --git a/dhcpv4/option_archtype.go b/dhcpv4/option_archtype.go index 59dadb3..00a4417 100644 --- a/dhcpv4/option_archtype.go +++ b/dhcpv4/option_archtype.go @@ -1,55 +1,23 @@ package dhcpv4 import ( - "fmt" - "github.com/insomniacslk/dhcp/iana" - "github.com/u-root/u-root/pkg/uio" ) -// OptClientArchType represents an option encapsulating the Client System -// Architecture Type option definition. See RFC 4578. -type OptClientArchType struct { - ArchTypes []iana.Arch -} - -// Code returns the option code. -func (o *OptClientArchType) Code() OptionCode { - return OptionClientSystemArchitectureType -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptClientArchType) ToBytes() []byte { - buf := uio.NewBigEndianBuffer(nil) - for _, at := range o.ArchTypes { - buf.Write16(uint16(at)) - } - return buf.Data() +// OptClientArch returns a new Client System Architecture Type option. +func OptClientArch(archs ...iana.Arch) Option { + return Option{Code: OptionClientSystemArchitectureType, Value: iana.Archs(archs)} } -// String returns a human-readable string. -func (o *OptClientArchType) String() string { - var archTypes string - for idx, at := range o.ArchTypes { - archTypes += at.String() - if idx < len(o.ArchTypes)-1 { - archTypes += ", " - } +// GetClientArch returns the Client System Architecture Type option. +func GetClientArch(o Options) []iana.Arch { + v := o.Get(OptionClientSystemArchitectureType) + if v == nil { + return nil } - return fmt.Sprintf("Client System Architecture Type -> %v", archTypes) -} - -// ParseOptClientArchType returns a new OptClientArchType from a byte stream, -// or error if any. -func ParseOptClientArchType(data []byte) (*OptClientArchType, error) { - buf := uio.NewBigEndianBuffer(data) - if buf.Len() == 0 { - return nil, fmt.Errorf("must have at least one archtype if option is present") - } - - archTypes := make([]iana.Arch, 0, buf.Len()/2) - for buf.Has(2) { - archTypes = append(archTypes, iana.Arch(buf.Read16())) + var archs iana.Archs + if err := archs.FromBytes(v); err != nil { + return nil } - return &OptClientArchType{ArchTypes: archTypes}, buf.FinError() + return archs } diff --git a/dhcpv4/option_archtype_test.go b/dhcpv4/option_archtype_test.go index 60f8864..fcf526b 100644 --- a/dhcpv4/option_archtype_test.go +++ b/dhcpv4/option_archtype_test.go @@ -8,53 +8,42 @@ import ( ) func TestParseOptClientArchType(t *testing.T) { - data := []byte{ + o := Options{OptionClientSystemArchitectureType.Code(): []byte{ 0, 6, // EFI_IA32 - } - opt, err := ParseOptClientArchType(data) - require.NoError(t, err) - require.Equal(t, opt.ArchTypes[0], iana.EFI_IA32) + }} + archs := GetClientArch(o) + require.NotNil(t, archs) + require.Equal(t, archs[0], iana.EFI_IA32) } func TestParseOptClientArchTypeMultiple(t *testing.T) { - data := []byte{ + o := Options{OptionClientSystemArchitectureType.Code(): []byte{ 0, 6, // EFI_IA32 0, 2, // EFI_ITANIUM - } - opt, err := ParseOptClientArchType(data) - require.NoError(t, err) - require.Equal(t, opt.ArchTypes[0], iana.EFI_IA32) - require.Equal(t, opt.ArchTypes[1], iana.EFI_ITANIUM) + }} + archs := GetClientArch(o) + require.NotNil(t, archs) + require.Equal(t, archs[0], iana.EFI_IA32) + require.Equal(t, archs[1], iana.EFI_ITANIUM) } func TestParseOptClientArchTypeInvalid(t *testing.T) { - data := []byte{42} - _, err := ParseOptClientArchType(data) - require.Error(t, err) + o := Options{OptionClientSystemArchitectureType.Code(): []byte{42}} + archs := GetClientArch(o) + require.Nil(t, archs) } -func TestOptClientArchTypeParseAndToBytes(t *testing.T) { - data := []byte{ - 0, 8, // EFI_XSCALE - } - opt, err := ParseOptClientArchType(data) - require.NoError(t, err) - require.Equal(t, opt.ToBytes(), data) +func TestGetClientArchEmpty(t *testing.T) { + require.Nil(t, GetClientArch(Options{})) } func TestOptClientArchTypeParseAndToBytesMultiple(t *testing.T) { data := []byte{ - 0, 8, // EFI_XSCALE 0, 6, // EFI_IA32 + 0, 8, // EFI_XSCALE } - opt, err := ParseOptClientArchType(data) - require.NoError(t, err) - require.Equal(t, opt.ToBytes(), data) -} - -func TestOptClientArchType(t *testing.T) { - opt := OptClientArchType{ - ArchTypes: []iana.Arch{iana.EFI_ITANIUM}, - } - require.Equal(t, opt.Code(), OptionClientSystemArchitectureType) + opt := OptClientArch(iana.EFI_IA32, iana.EFI_XSCALE) + require.Equal(t, opt.Value.ToBytes(), data) + require.Equal(t, opt.Code, OptionClientSystemArchitectureType) + require.Equal(t, opt.String(), "Client System Architecture Type: EFI IA32, EFI Xscale") } diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go index e352e34..6d2f7b2 100644 --- a/dhcpv4/option_domain_search.go +++ b/dhcpv4/option_domain_search.go @@ -1,41 +1,27 @@ package dhcpv4 import ( - "fmt" - "github.com/insomniacslk/dhcp/rfc1035label" ) -// OptDomainSearch implements the domain search list option described by RFC -// 3397, Section 2. +// OptDomainSearch returns a new domain search option. // -// FIXME: rename OptDomainSearch to OptDomainSearchList, and DomainSearch to -// SearchList, for consistency with the equivalent v6 option -type OptDomainSearch struct { - DomainSearch *rfc1035label.Labels -} - -// Code returns the option code. -func (op *OptDomainSearch) Code() OptionCode { - return OptionDNSDomainSearchList -} - -// ToBytes returns a serialized stream of bytes for this option. -func (op *OptDomainSearch) ToBytes() []byte { - return op.DomainSearch.ToBytes() +// The domain search option is described by RFC 3397, Section 2. +func OptDomainSearch(labels *rfc1035label.Labels) Option { + return Option{Code: OptionDNSDomainSearchList, Value: labels} } -// String returns a human-readable string. -func (op *OptDomainSearch) String() string { - return fmt.Sprintf("DNS Domain Search List -> %v", op.DomainSearch.Labels) -} - -// ParseOptDomainSearch returns a new OptDomainSearch from a byte stream, or -// error if any. -func ParseOptDomainSearch(data []byte) (*OptDomainSearch, error) { - labels, err := rfc1035label.FromBytes(data) +// GetDomainSearch returns the domain search list in o, if present. +// +// The domain search option is described by RFC 3397, Section 2. +func GetDomainSearch(o Options) *rfc1035label.Labels { + v := o.Get(OptionDNSDomainSearchList) + if v == nil { + return nil + } + labels, err := rfc1035label.FromBytes(v) if err != nil { - return nil, err + return nil } - return &OptDomainSearch{DomainSearch: labels}, nil + return labels } diff --git a/dhcpv4/option_domain_search_test.go b/dhcpv4/option_domain_search_test.go index 6425d57..9a508f2 100644 --- a/dhcpv4/option_domain_search_test.go +++ b/dhcpv4/option_domain_search_test.go @@ -7,17 +7,20 @@ import ( "github.com/stretchr/testify/require" ) -func TestParseOptDomainSearch(t *testing.T) { +func TestGetDomainSearch(t *testing.T) { data := []byte{ 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } - opt, err := ParseOptDomainSearch(data) - require.NoError(t, err) - require.Equal(t, 2, len(opt.DomainSearch.Labels)) - require.Equal(t, data, opt.DomainSearch.ToBytes()) - require.Equal(t, opt.DomainSearch.Labels[0], "example.com") - require.Equal(t, opt.DomainSearch.Labels[1], "subnet.example.org") + o := Options{ + OptionDNSDomainSearchList.Code(): data, + } + labels := GetDomainSearch(o) + require.NotNil(t, labels) + require.Equal(t, 2, len(labels.Labels)) + require.Equal(t, data, labels.ToBytes()) + require.Equal(t, labels.Labels[0], "example.com") + require.Equal(t, labels.Labels[1], "subnet.example.org") } func TestOptDomainSearchToBytes(t *testing.T) { @@ -25,13 +28,12 @@ func TestOptDomainSearchToBytes(t *testing.T) { 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } - opt := OptDomainSearch{ - DomainSearch: &rfc1035label.Labels{ - Labels: []string{ - "example.com", - "subnet.example.org", - }, + opt := OptDomainSearch(&rfc1035label.Labels{ + Labels: []string{ + "example.com", + "subnet.example.org", }, - } - require.Equal(t, opt.ToBytes(), expected) + }, + ) + require.Equal(t, opt.Value.ToBytes(), expected) } diff --git a/dhcpv4/option_generic.go b/dhcpv4/option_generic.go index 264340c..a54cdeb 100644 --- a/dhcpv4/option_generic.go +++ b/dhcpv4/option_generic.go @@ -1,7 +1,6 @@ package dhcpv4 import ( - "errors" "fmt" ) @@ -9,22 +8,7 @@ import ( // data. Every option that does not have a specific implementation will fall // back to this option. type OptionGeneric struct { - OptionCode OptionCode - Data []byte -} - -// ParseOptionGeneric parses a bytestream and creates a new OptionGeneric from -// it, or an error. -func ParseOptionGeneric(code OptionCode, data []byte) (Option, error) { - if len(data) == 0 { - return nil, errors.New("invalid zero-length bytestream") - } - return &OptionGeneric{OptionCode: code, Data: data}, nil -} - -// Code returns the generic option code. -func (o OptionGeneric) Code() OptionCode { - return o.OptionCode + Data []byte } // ToBytes returns a serialized generic option as a slice of bytes. @@ -34,5 +18,10 @@ func (o OptionGeneric) ToBytes() []byte { // String returns a human-readable representation of a generic option. func (o OptionGeneric) String() string { - return fmt.Sprintf("%v -> %v", o.OptionCode.String(), o.Data) + return fmt.Sprintf("%v", o.Data) +} + +// OptGeneric returns a generic option. +func OptGeneric(code OptionCode, value []byte) Option { + return Option{Code: code, Value: OptionGeneric{value}} } diff --git a/dhcpv4/option_generic_test.go b/dhcpv4/option_generic_test.go index ee35d65..4c4f2e8 100644 --- a/dhcpv4/option_generic_test.go +++ b/dhcpv4/option_generic_test.go @@ -6,42 +6,14 @@ import ( "github.com/stretchr/testify/require" ) -func TestParseOptionGeneric(t *testing.T) { - // Empty bytestream produces error - _, err := ParseOptionGeneric(OptionHostName, []byte{}) - require.Error(t, err, "error from empty bytestream") -} - func TestOptionGenericCode(t *testing.T) { - o := OptionGeneric{ - OptionCode: OptionDHCPMessageType, - Data: []byte{byte(MessageTypeDiscover)}, - } - require.Equal(t, OptionDHCPMessageType, o.Code()) -} - -func TestOptionGenericToBytes(t *testing.T) { - o := OptionGeneric{ - OptionCode: OptionDHCPMessageType, - Data: []byte{byte(MessageTypeDiscover)}, - } - serialized := o.ToBytes() - expected := []byte{1} - require.Equal(t, expected, serialized) -} - -func TestOptionGenericString(t *testing.T) { - o := OptionGeneric{ - OptionCode: OptionDHCPMessageType, - Data: []byte{byte(MessageTypeDiscover)}, - } - require.Equal(t, "DHCP Message Type -> [1]", o.String()) + o := OptGeneric(OptionDHCPMessageType, []byte{byte(MessageTypeDiscover)}) + require.Equal(t, OptionDHCPMessageType, o.Code) + require.Equal(t, []byte{1}, o.Value.ToBytes()) + require.Equal(t, "DHCP Message Type: [1]", o.String()) } func TestOptionGenericStringUnknown(t *testing.T) { - o := OptionGeneric{ - OptionCode: optionCode(102), // Returned option code. - Data: []byte{byte(MessageTypeDiscover)}, - } - require.Equal(t, "unknown (102) -> [1]", o.String()) + o := OptGeneric(optionCode(102), []byte{byte(MessageTypeDiscover)}) + require.Equal(t, "unknown (102): [1]", o.String()) } diff --git a/dhcpv4/option_ip.go b/dhcpv4/option_ip.go index ee0f5fe..6a4206c 100644 --- a/dhcpv4/option_ip.go +++ b/dhcpv4/option_ip.go @@ -1,92 +1,83 @@ package dhcpv4 import ( - "fmt" "net" "github.com/u-root/u-root/pkg/uio" ) -// OptBroadcastAddress implements the broadcast address option described in RFC -// 2132, Section 5.3. -type OptBroadcastAddress struct { - BroadcastAddress net.IP -} +// IP implements DHCPv4 IP option marshaling and unmarshaling as described by +// RFC 2132, Sections 5.3, 9.1, 9.7, and others. +type IP net.IP -// ParseOptBroadcastAddress returns a new OptBroadcastAddress from a byte -// stream, or error if any. -func ParseOptBroadcastAddress(data []byte) (*OptBroadcastAddress, error) { +// FromBytes parses an IP from data in binary form. +func (i *IP) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) - return &OptBroadcastAddress{BroadcastAddress: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() -} - -// Code returns the option code. -func (o *OptBroadcastAddress) Code() OptionCode { - return OptionBroadcastAddress + *i = IP(buf.CopyN(net.IPv4len)) + return buf.FinError() } // ToBytes returns a serialized stream of bytes for this option. -func (o *OptBroadcastAddress) ToBytes() []byte { - return []byte(o.BroadcastAddress.To4()) +func (i IP) ToBytes() []byte { + return []byte(net.IP(i).To4()) } -// String returns a human-readable string. -func (o *OptBroadcastAddress) String() string { - return fmt.Sprintf("Broadcast Address -> %v", o.BroadcastAddress.String()) +// String returns a human-readable IP. +func (i IP) String() string { + return net.IP(i).String() } -// OptRequestedIPAddress implements the requested IP address option described -// by RFC 2132, Section 9.1. -type OptRequestedIPAddress struct { - RequestedAddr net.IP +// GetIP returns code out of o parsed as an IP. +func GetIP(code OptionCode, o Options) net.IP { + v := o.Get(code) + if v == nil { + return nil + } + var ip IP + if err := ip.FromBytes(v); err != nil { + return nil + } + return net.IP(ip) } -// ParseOptRequestedIPAddress returns a new OptRequestedIPAddress from a byte -// stream, or error if any. -func ParseOptRequestedIPAddress(data []byte) (*OptRequestedIPAddress, error) { - buf := uio.NewBigEndianBuffer(data) - return &OptRequestedIPAddress{RequestedAddr: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() +// GetBroadcastAddress returns the DHCPv4 Broadcast Address value in o. +// +// The broadcast address option is described in RFC 2132, Section 5.3. +func GetBroadcastAddress(o Options) net.IP { + return GetIP(OptionBroadcastAddress, o) } -// Code returns the option code. -func (o *OptRequestedIPAddress) Code() OptionCode { - return OptionRequestedIPAddress -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptRequestedIPAddress) ToBytes() []byte { - return o.RequestedAddr.To4() +// OptBroadcastAddress returns a new DHCPv4 Broadcast Address option. +// +// The broadcast address option is described in RFC 2132, Section 5.3. +func OptBroadcastAddress(ip net.IP) Option { + return Option{Code: OptionBroadcastAddress, Value: IP(ip)} } -// String returns a human-readable string. -func (o *OptRequestedIPAddress) String() string { - return fmt.Sprintf("Requested IP Address -> %v", o.RequestedAddr.String()) +// GetRequestedIPAddress returns the DHCPv4 Requested IP Address value in o. +// +// The requested IP address option is described by RFC 2132, Section 9.1. +func GetRequestedIPAddress(o Options) net.IP { + return GetIP(OptionRequestedIPAddress, o) } -// OptServerIdentifier implements the server identifier option described by RFC -// 2132, Section 9.7. -type OptServerIdentifier struct { - ServerID net.IP +// OptRequestedIPAddress returns a new DHCPv4 Requested IP Address option. +// +// The requested IP address option is described by RFC 2132, Section 9.1. +func OptRequestedIPAddress(ip net.IP) Option { + return Option{Code: OptionRequestedIPAddress, Value: IP(ip)} } -// ParseOptServerIdentifier returns a new OptServerIdentifier from a byte -// stream, or error if any. -func ParseOptServerIdentifier(data []byte) (*OptServerIdentifier, error) { - buf := uio.NewBigEndianBuffer(data) - return &OptServerIdentifier{ServerID: net.IP(buf.CopyN(net.IPv4len))}, buf.FinError() -} - -// Code returns the option code. -func (o *OptServerIdentifier) Code() OptionCode { - return OptionServerIdentifier -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptServerIdentifier) ToBytes() []byte { - return o.ServerID.To4() +// GetServerIdentifier returns the DHCPv4 Server Identifier value in o. +// +// The server identifier option is described by RFC 2132, Section 9.7. +func GetServerIdentifier(o Options) net.IP { + return GetIP(OptionServerIdentifier, o) } -// String returns a human-readable string. -func (o *OptServerIdentifier) String() string { - return fmt.Sprintf("Server Identifier -> %v", o.ServerID.String()) +// OptServerIdentifier returns a new DHCPv4 Server Identifier option. +// +// The server identifier option is described by RFC 2132, Section 9.7. +func OptServerIdentifier(ip net.IP) Option { + return Option{Code: OptionServerIdentifier, Value: IP(ip)} } diff --git a/dhcpv4/option_ip_address_lease_time.go b/dhcpv4/option_ip_address_lease_time.go index 4362419..6e09233 100644 --- a/dhcpv4/option_ip_address_lease_time.go +++ b/dhcpv4/option_ip_address_lease_time.go @@ -2,37 +2,53 @@ package dhcpv4 import ( "fmt" + "time" "github.com/u-root/u-root/pkg/uio" ) -// OptIPAddressLeaseTime implements the IP address lease time option described -// by RFC 2132, Section 9.2. -type OptIPAddressLeaseTime struct { - LeaseTime uint32 -} +// Duration implements the IP address lease time option described by RFC 2132, +// Section 9.2. +type Duration time.Duration -// ParseOptIPAddressLeaseTime constructs an OptIPAddressLeaseTime struct from a -// sequence of bytes and returns it, or an error. -func ParseOptIPAddressLeaseTime(data []byte) (*OptIPAddressLeaseTime, error) { +// FromBytes parses a duration from a byte stream according to RFC 2132, Section 9.2. +func (d *Duration) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) - leaseTime := buf.Read32() - return &OptIPAddressLeaseTime{LeaseTime: leaseTime}, buf.FinError() -} - -// Code returns the option code. -func (o *OptIPAddressLeaseTime) Code() OptionCode { - return OptionIPAddressLeaseTime + *d = Duration(time.Duration(buf.Read32()) * time.Second) + return buf.FinError() } // ToBytes returns a serialized stream of bytes for this option. -func (o *OptIPAddressLeaseTime) ToBytes() []byte { +func (d Duration) ToBytes() []byte { buf := uio.NewBigEndianBuffer(nil) - buf.Write32(o.LeaseTime) + buf.Write32(uint32(time.Duration(d) / time.Second)) return buf.Data() } // String returns a human-readable string for this option. -func (o *OptIPAddressLeaseTime) String() string { - return fmt.Sprintf("IP Addresses Lease Time -> %v", o.LeaseTime) +func (d Duration) String() string { + return fmt.Sprintf("%s", time.Duration(d)) +} + +// OptIPAddressLeaseTime returns a new IP address lease time option. +// +// The IP address lease time option is described by RFC 2132, Section 9.2. +func OptIPAddressLeaseTime(d time.Duration) Option { + return Option{Code: OptionIPAddressLeaseTime, Value: Duration(d)} +} + +// GetIPAddressLeaseTime returns the IP address lease time in o, or the given +// default duration if not present. +// +// The IP address lease time option is described by RFC 2132, Section 9.2. +func GetIPAddressLeaseTime(o Options, def time.Duration) time.Duration { + v := o.Get(OptionIPAddressLeaseTime) + if v == nil { + return def + } + var d Duration + if err := d.FromBytes(v); err != nil { + return def + } + return time.Duration(d) } diff --git a/dhcpv4/option_ip_address_lease_time_test.go b/dhcpv4/option_ip_address_lease_time_test.go index 384db1c..70c4047 100644 --- a/dhcpv4/option_ip_address_lease_time_test.go +++ b/dhcpv4/option_ip_address_lease_time_test.go @@ -2,34 +2,33 @@ package dhcpv4 import ( "testing" + "time" "github.com/stretchr/testify/require" ) -func TestOptIPAddressLeaseTimeInterfaceMethods(t *testing.T) { - o := OptIPAddressLeaseTime{LeaseTime: 43200} - require.Equal(t, OptionIPAddressLeaseTime, o.Code(), "Code") - require.Equal(t, []byte{0, 0, 168, 192}, o.ToBytes(), "ToBytes") +func TestOptIPAddressLeaseTime(t *testing.T) { + o := OptIPAddressLeaseTime(43200 * time.Second) + require.Equal(t, OptionIPAddressLeaseTime, o.Code, "Code") + require.Equal(t, []byte{0, 0, 168, 192}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "IP Addresses Lease Time: 12h0m0s", o.String(), "String") } -func TestParseOptIPAddressLeaseTime(t *testing.T) { - data := []byte{0, 0, 168, 192} - o, err := ParseOptIPAddressLeaseTime(data) - require.NoError(t, err) - require.Equal(t, &OptIPAddressLeaseTime{LeaseTime: 43200}, o) +func TestGetIPAddressLeaseTime(t *testing.T) { + o := Options{OptionIPAddressLeaseTime.Code(): []byte{0, 0, 168, 192}} + leaseTime := GetIPAddressLeaseTime(o, 0) + require.Equal(t, 43200*time.Second, leaseTime) - // Short byte stream - data = []byte{168, 192} - _, err = ParseOptIPAddressLeaseTime(data) - require.Error(t, err, "should get error from short byte stream") + // Too short. + o = Options{OptionIPAddressLeaseTime.Code(): []byte{168, 192}} + leaseTime = GetIPAddressLeaseTime(o, 0) + require.Equal(t, time.Duration(0), leaseTime) - // Bad length - data = []byte{1, 1, 1, 1, 1} - _, err = ParseOptIPAddressLeaseTime(data) - require.Error(t, err, "should get error from bad length") -} + // Too long. + o = Options{OptionIPAddressLeaseTime.Code(): []byte{1, 1, 1, 1, 1}} + leaseTime = GetIPAddressLeaseTime(o, 0) + require.Equal(t, time.Duration(0), leaseTime) -func TestOptIPAddressLeaseTimeString(t *testing.T) { - o := OptIPAddressLeaseTime{LeaseTime: 43200} - require.Equal(t, "IP Addresses Lease Time -> 43200", o.String()) + // Empty. + require.Equal(t, time.Duration(10), GetIPAddressLeaseTime(Options{}, 10)) } diff --git a/dhcpv4/option_ip_test.go b/dhcpv4/option_ip_test.go index fe31487..e772224 100644 --- a/dhcpv4/option_ip_test.go +++ b/dhcpv4/option_ip_test.go @@ -8,61 +8,59 @@ import ( ) func TestOptBroadcastAddress(t *testing.T) { - o := OptBroadcastAddress{BroadcastAddress: net.IP{192, 168, 0, 1}} + ip := net.IP{192, 168, 0, 1} + o := OptBroadcastAddress(ip) - require.Equal(t, OptionBroadcastAddress, o.Code(), "Code") - require.Equal(t, []byte{192, 168, 0, 1}, o.ToBytes(), "ToBytes") - require.Equal(t, "Broadcast Address -> 192.168.0.1", o.String(), "String") + require.Equal(t, OptionBroadcastAddress, o.Code, "Code") + require.Equal(t, []byte(ip), o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Broadcast Address: 192.168.0.1", o.String(), "String") } -func TestParseOptBroadcastAddress(t *testing.T) { - o, err := ParseOptBroadcastAddress([]byte{}) - require.Error(t, err, "empty byte stream") - - o, err = ParseOptBroadcastAddress([]byte{192, 168, 0}) - require.Error(t, err, "wrong IP length") +func TestGetIPs(t *testing.T) { + o := Options{102: []byte{}} + i := GetIPs(optionCode(102), o) + require.Nil(t, i) - o, err = ParseOptBroadcastAddress([]byte{192, 168, 0, 1}) - require.NoError(t, err) - require.Equal(t, net.IP{192, 168, 0, 1}, o.BroadcastAddress) -} + o = Options{102: []byte{192, 168, 0}} + i = GetIPs(optionCode(102), o) + require.Nil(t, i) -func TestOptRequestedIPAddress(t *testing.T) { - o := OptRequestedIPAddress{RequestedAddr: net.IP{192, 168, 0, 1}} + o = Options{102: []byte{192, 168, 0, 1}} + i = GetIPs(optionCode(102), o) + require.Equal(t, i, []net.IP{{192, 168, 0, 1}}) - require.Equal(t, OptionRequestedIPAddress, o.Code(), "Code") - require.Equal(t, []byte{192, 168, 0, 1}, o.ToBytes(), "ToBytes") - require.Equal(t, "Requested IP Address -> 192.168.0.1", o.String(), "String") + o = Options{102: []byte{192, 168, 0, 1, 192, 168, 0, 2}} + i = GetIPs(optionCode(102), o) + require.Equal(t, i, []net.IP{{192, 168, 0, 1}, {192, 168, 0, 2}}) } -func TestParseOptRequestedIPAddress(t *testing.T) { - o, err := ParseOptRequestedIPAddress([]byte{}) +func TestParseIP(t *testing.T) { + var ip IP + err := ip.FromBytes([]byte{}) require.Error(t, err, "empty byte stream") - o, err = ParseOptRequestedIPAddress([]byte{192}) + err = ip.FromBytes([]byte{192, 168, 0}) require.Error(t, err, "wrong IP length") - o, err = ParseOptRequestedIPAddress([]byte{192, 168, 0, 1}) + err = ip.FromBytes([]byte{192, 168, 0, 1}) require.NoError(t, err) - require.Equal(t, net.IP{192, 168, 0, 1}, o.RequestedAddr) + require.Equal(t, net.IP{192, 168, 0, 1}, net.IP(ip)) } -func TestOptServerIdentifierInterfaceMethods(t *testing.T) { - o := OptServerIdentifier{ServerID: net.IP{192, 168, 0, 1}} +func TestOptRequestedIPAddress(t *testing.T) { + ip := net.IP{192, 168, 0, 1} + o := OptRequestedIPAddress(ip) - require.Equal(t, OptionServerIdentifier, o.Code(), "Code") - require.Equal(t, []byte{192, 168, 0, 1}, o.ToBytes(), "ToBytes") - require.Equal(t, "Server Identifier -> 192.168.0.1", o.String(), "String") + require.Equal(t, OptionRequestedIPAddress, o.Code, "Code") + require.Equal(t, []byte(ip), o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Requested IP Address: 192.168.0.1", o.String(), "String") } -func TestParseOptServerIdentifier(t *testing.T) { - o, err := ParseOptServerIdentifier([]byte{}) - require.Error(t, err, "empty byte stream") +func TestOptServerIdentifier(t *testing.T) { + ip := net.IP{192, 168, 0, 1} + o := OptServerIdentifier(ip) - o, err = ParseOptServerIdentifier([]byte{192, 168, 0}) - require.Error(t, err, "wrong IP length") - - o, err = ParseOptServerIdentifier([]byte{192, 168, 0, 1}) - require.NoError(t, err) - require.Equal(t, net.IP{192, 168, 0, 1}, o.ServerID) + require.Equal(t, OptionServerIdentifier, o.Code, "Code") + require.Equal(t, []byte(ip), o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Server Identifier: 192.168.0.1", o.String(), "String") } diff --git a/dhcpv4/option_ips.go b/dhcpv4/option_ips.go index 8792ed0..693d62d 100644 --- a/dhcpv4/option_ips.go +++ b/dhcpv4/option_ips.go @@ -8,36 +8,40 @@ import ( "github.com/u-root/u-root/pkg/uio" ) -// ParseIPs parses an IPv4 address from a DHCP packet as used and specified by +// IPs are IPv4 addresses from a DHCP packet as used and specified by options +// in RFC 2132, Sections 3.5 through 3.13, 8.2, 8.3, 8.5, 8.6, 8.9, and 8.10. +// +// IPs implements the OptionValue type. +type IPs []net.IP + +// FromBytes parses an IPv4 address from a DHCP packet as used and specified by // options in RFC 2132, Sections 3.5 through 3.13, 8.2, 8.3, 8.5, 8.6, 8.9, and // 8.10. -func ParseIPs(data []byte) ([]net.IP, error) { +func (i *IPs) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) - if buf.Len() == 0 { - return nil, fmt.Errorf("IP DHCP options must always list at least one IP") + return fmt.Errorf("IP DHCP options must always list at least one IP") } - ips := make([]net.IP, 0, buf.Len()/net.IPv4len) + *i = make(IPs, 0, buf.Len()/net.IPv4len) for buf.Has(net.IPv4len) { - ips = append(ips, net.IP(buf.CopyN(net.IPv4len))) + *i = append(*i, net.IP(buf.CopyN(net.IPv4len))) } - return ips, buf.FinError() + return buf.FinError() } -// IPsToBytes marshals an IPv4 address to a DHCP packet as specified by RFC -// 2132, Section 3.5 et al. -func IPsToBytes(i []net.IP) []byte { +// ToBytes marshals IPv4 addresses to a DHCP packet as specified by RFC 2132, +// Section 3.5 et al. +func (i IPs) ToBytes() []byte { buf := uio.NewBigEndianBuffer(nil) - for _, ip := range i { buf.WriteBytes(ip.To4()) } return buf.Data() } -// IPsToString returns a human-readable representation of a list of IPs. -func IPsToString(i []net.IP) string { +// String returns a human-readable representation of a list of IPs. +func (i IPs) String() string { s := make([]string, 0, len(i)) for _, ip := range i { s = append(s, ip.String()) @@ -45,92 +49,76 @@ func IPsToString(i []net.IP) string { return strings.Join(s, ", ") } -// OptRouter implements the router option described by RFC 2132, Section 3.5. -type OptRouter struct { - Routers []net.IP -} - -// ParseOptRouter returns a new OptRouter from a byte stream, or error if any. -func ParseOptRouter(data []byte) (*OptRouter, error) { - ips, err := ParseIPs(data) - if err != nil { - return nil, err +// GetIPs parses a list of IPs from code in o. +func GetIPs(code OptionCode, o Options) []net.IP { + v := o.Get(code) + if v == nil { + return nil } - return &OptRouter{Routers: ips}, nil -} - -// Code returns the option code. -func (o *OptRouter) Code() OptionCode { - return OptionRouter -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptRouter) ToBytes() []byte { - return IPsToBytes(o.Routers) -} - -// String returns a human-readable string. -func (o *OptRouter) String() string { - return fmt.Sprintf("Routers -> %s", IPsToString(o.Routers)) -} - -// OptNTPServers implements the NTP servers option described by RFC 2132, -// Section 8.3. -type OptNTPServers struct { - NTPServers []net.IP -} - -// ParseOptNTPServers returns a new OptNTPServers from a byte stream, or error if any. -func ParseOptNTPServers(data []byte) (*OptNTPServers, error) { - ips, err := ParseIPs(data) - if err != nil { - return nil, err + var ips IPs + if err := ips.FromBytes(v); err != nil { + return nil } - return &OptNTPServers{NTPServers: ips}, nil + return []net.IP(ips) } -// Code returns the option code. -func (o *OptNTPServers) Code() OptionCode { - return OptionNTPServers +// GetRouter parses the DHCPv4 Router option if present. +// +// The Router option is described by RFC 2132, Section 3.5. +func GetRouter(o Options) []net.IP { + return GetIPs(OptionRouter, o) } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptNTPServers) ToBytes() []byte { - return IPsToBytes(o.NTPServers) +// OptRouter returns a new DHCPv4 Router option. +// +// The Router option is described by RFC 2132, Section 3.5. +func OptRouter(routers ...net.IP) Option { + return Option{ + Code: OptionRouter, + Value: IPs(routers), + } } -// String returns a human-readable string. -func (o *OptNTPServers) String() string { - return fmt.Sprintf("NTP Servers -> %v", IPsToString(o.NTPServers)) +// WithRouter updates a packet with the DHCPv4 Router option. +func WithRouter(routers ...net.IP) Modifier { + return WithOption(OptRouter(routers...)) } -// OptDomainNameServer implements the DNS server option described by RFC 2132, -// Section 3.8. -type OptDomainNameServer struct { - NameServers []net.IP +// GetNTPServers parses the DHCPv4 NTP Servers option if present. +// +// The NTP servers option is described by RFC 2132, Section 8.3. +func GetNTPServers(o Options) []net.IP { + return GetIPs(OptionNTPServers, o) } -// ParseOptDomainNameServer returns a new OptDomainNameServer from a byte -// stream, or error if any. -func ParseOptDomainNameServer(data []byte) (*OptDomainNameServer, error) { - ips, err := ParseIPs(data) - if err != nil { - return nil, err +// OptNTPServers returns a new DHCPv4 NTP Server option. +// +// The NTP servers option is described by RFC 2132, Section 8.3. +func OptNTPServers(ntpServers ...net.IP) Option { + return Option{ + Code: OptionNTPServers, + Value: IPs(ntpServers), } - return &OptDomainNameServer{NameServers: ips}, nil } -// Code returns the option code. -func (o *OptDomainNameServer) Code() OptionCode { - return OptionDomainNameServer +// GetDNS parses the DHCPv4 Domain Name Server option if present. +// +// The DNS server option is described by RFC 2132, Section 3.8. +func GetDNS(o Options) []net.IP { + return GetIPs(OptionDomainNameServer, o) } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptDomainNameServer) ToBytes() []byte { - return IPsToBytes(o.NameServers) +// OptDNS returns a new DHCPv4 Domain Name Server option. +// +// The DNS server option is described by RFC 2132, Section 3.8. +func OptDNS(servers ...net.IP) Option { + return Option{ + Code: OptionDomainNameServer, + Value: IPs(servers), + } } -// String returns a human-readable string. -func (o *OptDomainNameServer) String() string { - return fmt.Sprintf("Domain Name Servers -> %s", IPsToString(o.NameServers)) +// WithDNS modifies a packet with the DHCPv4 Domain Name Server option. +func WithDNS(servers ...net.IP) Modifier { + return WithOption(OptDNS(servers...)) } diff --git a/dhcpv4/option_ips_test.go b/dhcpv4/option_ips_test.go index 5efd537..1e1b772 100644 --- a/dhcpv4/option_ips_test.go +++ b/dhcpv4/option_ips_test.go @@ -7,132 +7,77 @@ import ( "github.com/stretchr/testify/require" ) -func TestOptRoutersInterfaceMethods(t *testing.T) { - routers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), - } - o := OptRouter{Routers: routers} - require.Equal(t, OptionRouter, o.Code(), "Code") - require.Equal(t, routers, o.Routers, "Routers") -} - -func TestParseOptRouter(t *testing.T) { - data := []byte{ - byte(OptionRouter), - 8, // Length - 192, 168, 0, 10, // Router #1 - 192, 168, 0, 20, // Router #2 - } - o, err := ParseOptRouter(data[2:]) - require.NoError(t, err) - routers := []net.IP{ - net.IP{192, 168, 0, 10}, - net.IP{192, 168, 0, 20}, - } - require.Equal(t, &OptRouter{Routers: routers}, o) - - // Short byte stream - data = []byte{byte(OptionRouter)} - _, err = ParseOptRouter(data) - require.Error(t, err, "should get error from short byte stream") -} - -func TestParseOptRouterNoRouters(t *testing.T) { - // RFC2132 requires that at least one Router IP is specified - data := []byte{ - byte(OptionRouter), - 0, // Length - } - _, err := ParseOptRouter(data) - require.Error(t, err) -} - -func TestOptRouterString(t *testing.T) { - o := OptRouter{Routers: []net.IP{net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 10}}} - require.Equal(t, "Routers -> 192.168.0.1, 192.168.0.10", o.String()) -} - -func TestOptDomainNameServerInterfaceMethods(t *testing.T) { - servers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), - } - o := OptDomainNameServer{NameServers: servers} - require.Equal(t, OptionDomainNameServer, o.Code(), "Code") - require.Equal(t, servers, o.NameServers, "NameServers") -} - -func TestParseOptDomainNameServer(t *testing.T) { +func TestParseIPs(t *testing.T) { + var i IPs data := []byte{ - byte(OptionDomainNameServer), - 8, // Length 192, 168, 0, 10, // DNS #1 192, 168, 0, 20, // DNS #2 } - o, err := ParseOptDomainNameServer(data[2:]) + err := i.FromBytes(data) require.NoError(t, err) servers := []net.IP{ net.IP{192, 168, 0, 10}, net.IP{192, 168, 0, 20}, } - require.Equal(t, &OptDomainNameServer{NameServers: servers}, o) + require.Equal(t, servers, []net.IP(i)) // Bad length data = []byte{1, 1, 1} - _, err = ParseOptDomainNameServer(data) + err = i.FromBytes(data) require.Error(t, err, "should get error from bad length") -} -func TestParseOptDomainNameServerNoServers(t *testing.T) { - // RFC2132 requires that at least one DNS server IP is specified - _, err := ParseOptDomainNameServer([]byte{}) + // RFC2132 requires that at least one IP is specified for each IP field. + err = i.FromBytes([]byte{}) require.Error(t, err) } -func TestOptDomainNameServerString(t *testing.T) { - o := OptDomainNameServer{NameServers: []net.IP{net.IPv4(192, 168, 0, 1), net.IPv4(192, 168, 0, 10)}} - require.Equal(t, "Domain Name Servers -> 192.168.0.1, 192.168.0.10", o.String()) +func TestOptDomainNameServer(t *testing.T) { + o := OptDNS(net.IPv4(192, 168, 0, 1), net.IPv4(192, 168, 0, 10)) + require.Equal(t, OptionDomainNameServer, o.Code) + require.Equal(t, []byte{192, 168, 0, 1, 192, 168, 0, 10}, o.Value.ToBytes()) + require.Equal(t, "Domain Name Server: 192.168.0.1, 192.168.0.10", o.String()) } -func TestOptNTPServersInterfaceMethods(t *testing.T) { - ntpServers := []net.IP{ - net.IPv4(192, 168, 0, 10), - net.IPv4(192, 168, 0, 20), +func TestGetDomainNameServer(t *testing.T) { + ips := []net.IP{ + net.IP{192, 168, 0, 1}, + net.IP{192, 168, 0, 10}, } - o := OptNTPServers{NTPServers: ntpServers} - require.Equal(t, OptionNTPServers, o.Code(), "Code") - require.Equal(t, ntpServers, o.NTPServers, "NTPServers") + o := OptionsFromList(OptDNS(ips...)) + require.Equal(t, ips, GetDNS(o)) + require.Nil(t, GetDNS(Options{})) } -func TestParseOptNTPServers(t *testing.T) { - data := []byte{ - byte(OptionNTPServers), - 8, // Length - 192, 168, 0, 10, // NTP server #1 - 192, 168, 0, 20, // NTP server #2 - } - o, err := ParseOptNTPServers(data[2:]) - require.NoError(t, err) - ntpServers := []net.IP{ +func TestOptNTPServers(t *testing.T) { + o := OptNTPServers(net.IPv4(192, 168, 0, 1), net.IPv4(192, 168, 0, 10)) + require.Equal(t, OptionNTPServers, o.Code) + require.Equal(t, []byte{192, 168, 0, 1, 192, 168, 0, 10}, o.Value.ToBytes()) + require.Equal(t, "NTP Servers: 192.168.0.1, 192.168.0.10", o.String()) +} + +func TestGetNTPServers(t *testing.T) { + ips := []net.IP{ + net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 10}, - net.IP{192, 168, 0, 20}, } - require.Equal(t, &OptNTPServers{NTPServers: ntpServers}, o) - - // Bad length - data = []byte{1, 1, 1} - _, err = ParseOptNTPServers(data) - require.Error(t, err, "should get error from bad length") + o := OptionsFromList(OptNTPServers(ips...)) + require.Equal(t, ips, GetNTPServers(o)) + require.Nil(t, GetNTPServers(Options{})) } -func TestParseOptNTPserversNoNTPServers(t *testing.T) { - // RFC2132 requires that at least one NTP server IP is specified - _, err := ParseOptNTPServers([]byte{}) - require.Error(t, err) +func TestOptRouter(t *testing.T) { + o := OptRouter(net.IPv4(192, 168, 0, 1), net.IPv4(192, 168, 0, 10)) + require.Equal(t, OptionRouter, o.Code) + require.Equal(t, []byte{192, 168, 0, 1, 192, 168, 0, 10}, o.Value.ToBytes()) + require.Equal(t, "Router: 192.168.0.1, 192.168.0.10", o.String()) } -func TestOptNTPServersString(t *testing.T) { - o := OptNTPServers{NTPServers: []net.IP{net.IPv4(192, 168, 0, 1), net.IPv4(192, 168, 0, 10)}} - require.Equal(t, "NTP Servers -> 192.168.0.1, 192.168.0.10", o.String()) +func TestGetRouter(t *testing.T) { + ips := []net.IP{ + net.IP{192, 168, 0, 1}, + net.IP{192, 168, 0, 10}, + } + o := OptionsFromList(OptRouter(ips...)) + require.Equal(t, ips, GetRouter(o)) + require.Nil(t, GetRouter(Options{})) } diff --git a/dhcpv4/option_maximum_dhcp_message_size.go b/dhcpv4/option_maximum_dhcp_message_size.go index 904f3d2..900eea5 100644 --- a/dhcpv4/option_maximum_dhcp_message_size.go +++ b/dhcpv4/option_maximum_dhcp_message_size.go @@ -6,32 +6,52 @@ import ( "github.com/u-root/u-root/pkg/uio" ) -// OptMaximumDHCPMessageSize implements the maximum DHCP message size option -// described by RFC 2132, Section 9.10. -type OptMaximumDHCPMessageSize struct { - Size uint16 +// Uint16 implements encoding and decoding functions for a uint16 as used in +// RFC 2132, Section 9.10. +type Uint16 uint16 + +// ToBytes returns a serialized stream of bytes for this option. +func (o Uint16) ToBytes() []byte { + buf := uio.NewBigEndianBuffer(nil) + buf.Write16(uint16(o)) + return buf.Data() } -// ParseOptMaximumDHCPMessageSize constructs an OptMaximumDHCPMessageSize -// struct from a sequence of bytes and returns it, or an error. -func ParseOptMaximumDHCPMessageSize(data []byte) (*OptMaximumDHCPMessageSize, error) { +// String returns a human-readable string for this option. +func (o Uint16) String() string { + return fmt.Sprintf("%d", uint16(o)) +} + +// FromBytes decodes data into o as per RFC 2132, Section 9.10. +func (o *Uint16) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) - return &OptMaximumDHCPMessageSize{Size: buf.Read16()}, buf.FinError() + *o = Uint16(buf.Read16()) + return buf.FinError() } -// Code returns the option code. -func (o *OptMaximumDHCPMessageSize) Code() OptionCode { - return OptionMaximumDHCPMessageSize +// GetUint16 parses a uint16 from code in o. +func GetUint16(code OptionCode, o Options) (uint16, error) { + v := o.Get(code) + if v == nil { + return 0, fmt.Errorf("option not present") + } + var u Uint16 + if err := u.FromBytes(v); err != nil { + return 0, err + } + return uint16(u), nil } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptMaximumDHCPMessageSize) ToBytes() []byte { - buf := uio.NewBigEndianBuffer(nil) - buf.Write16(o.Size) - return buf.Data() +// OptMaxMessageSize returns a new DHCP Maximum Message Size option. +// +// The Maximum DHCP Message Size option is described by RFC 2132, Section 9.10. +func OptMaxMessageSize(size uint16) Option { + return Option{Code: OptionMaximumDHCPMessageSize, Value: Uint16(size)} } -// String returns a human-readable string for this option. -func (o *OptMaximumDHCPMessageSize) String() string { - return fmt.Sprintf("Maximum DHCP Message Size -> %v", o.Size) +// GetMaxMessageSize returns the DHCP Maximum Message Size in o if present. +// +// The Maximum DHCP Message Size option is described by RFC 2132, Section 9.10. +func GetMaxMessageSize(o Options) (uint16, error) { + return GetUint16(OptionMaximumDHCPMessageSize, o) } diff --git a/dhcpv4/option_maximum_dhcp_message_size_test.go b/dhcpv4/option_maximum_dhcp_message_size_test.go index 0b36b0e..147f280 100644 --- a/dhcpv4/option_maximum_dhcp_message_size_test.go +++ b/dhcpv4/option_maximum_dhcp_message_size_test.go @@ -6,30 +6,26 @@ import ( "github.com/stretchr/testify/require" ) -func TestOptMaximumDHCPMessageSizeInterfaceMethods(t *testing.T) { - o := OptMaximumDHCPMessageSize{Size: 1500} - require.Equal(t, OptionMaximumDHCPMessageSize, o.Code(), "Code") - require.Equal(t, []byte{5, 220}, o.ToBytes(), "ToBytes") +func TestOptMaximumDHCPMessageSize(t *testing.T) { + o := OptMaxMessageSize(1500) + require.Equal(t, OptionMaximumDHCPMessageSize, o.Code, "Code") + require.Equal(t, []byte{5, 220}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Maximum DHCP Message Size: 1500", o.String()) } -func TestParseOptMaximumDHCPMessageSize(t *testing.T) { - data := []byte{5, 220} - o, err := ParseOptMaximumDHCPMessageSize(data) +func TestGetMaximumDHCPMessageSize(t *testing.T) { + options := Options{OptionMaximumDHCPMessageSize.Code(): []byte{5, 220}} + o, err := GetMaxMessageSize(options) require.NoError(t, err) - require.Equal(t, &OptMaximumDHCPMessageSize{Size: 1500}, o) + require.Equal(t, uint16(1500), o) // Short byte stream - data = []byte{2} - _, err = ParseOptMaximumDHCPMessageSize(data) + options = Options{OptionMaximumDHCPMessageSize.Code(): []byte{2}} + _, err = GetMaxMessageSize(options) require.Error(t, err, "should get error from short byte stream") // Bad length - data = []byte{1, 1, 1} - _, err = ParseOptMaximumDHCPMessageSize(data) + options = Options{OptionMaximumDHCPMessageSize.Code(): []byte{1, 1, 1}} + _, err = GetMaxMessageSize(options) require.Error(t, err, "should get error from bad length") } - -func TestOptMaximumDHCPMessageSizeString(t *testing.T) { - o := OptMaximumDHCPMessageSize{Size: 1500} - require.Equal(t, "Maximum DHCP Message Size -> 1500", o.String()) -} diff --git a/dhcpv4/option_message_type.go b/dhcpv4/option_message_type.go index 2857e41..141e3b7 100644 --- a/dhcpv4/option_message_type.go +++ b/dhcpv4/option_message_type.go @@ -1,35 +1,19 @@ package dhcpv4 -import ( - "fmt" - - "github.com/u-root/u-root/pkg/uio" -) - -// OptMessageType implements the DHCP message type option described by RFC -// 2132, Section 9.6. -type OptMessageType struct { - MessageType MessageType -} - -// ParseOptMessageType constructs an OptMessageType struct from a sequence of -// bytes and returns it, or an error. -func ParseOptMessageType(data []byte) (*OptMessageType, error) { - buf := uio.NewBigEndianBuffer(data) - return &OptMessageType{MessageType: MessageType(buf.Read8())}, buf.FinError() -} - -// Code returns the option code. -func (o *OptMessageType) Code() OptionCode { - return OptionDHCPMessageType -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptMessageType) ToBytes() []byte { - return []byte{byte(o.MessageType)} +// OptMessageType returns a new DHCPv4 Message Type option. +func OptMessageType(m MessageType) Option { + return Option{Code: OptionDHCPMessageType, Value: m} } -// String returns a human-readable string for this option. -func (o *OptMessageType) String() string { - return fmt.Sprintf("DHCP Message Type -> %s", o.MessageType.String()) +// GetMessageType returns the DHCPv4 Message Type option in o. +func GetMessageType(o Options) MessageType { + v := o.Get(OptionDHCPMessageType) + if v == nil { + return MessageTypeNone + } + var m MessageType + if err := m.FromBytes(v); err != nil { + return MessageTypeNone + } + return m } diff --git a/dhcpv4/option_message_type_test.go b/dhcpv4/option_message_type_test.go index c3b4904..a97889e 100644 --- a/dhcpv4/option_message_type_test.go +++ b/dhcpv4/option_message_type_test.go @@ -6,36 +6,32 @@ import ( "github.com/stretchr/testify/require" ) -func TestOptMessageTypeInterfaceMethods(t *testing.T) { - o := OptMessageType{MessageType: MessageTypeDiscover} - require.Equal(t, OptionDHCPMessageType, o.Code(), "Code") - require.Equal(t, []byte{1}, o.ToBytes(), "ToBytes") -} +func TestOptMessageType(t *testing.T) { + o := OptMessageType(MessageTypeDiscover) + require.Equal(t, OptionDHCPMessageType, o.Code, "Code") + require.Equal(t, []byte{1}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "DHCP Message Type: DISCOVER", o.String()) -func TestOptMessageTypeNew(t *testing.T) { - o := OptMessageType{MessageType: MessageTypeDiscover} - require.Equal(t, OptionDHCPMessageType, o.Code()) - require.Equal(t, MessageTypeDiscover, o.MessageType) + // unknown + o = OptMessageType(99) + require.Equal(t, "DHCP Message Type: unknown (99)", o.String()) } func TestParseOptMessageType(t *testing.T) { + var m MessageType data := []byte{1} // DISCOVER - o, err := ParseOptMessageType(data) + err := m.FromBytes(data) require.NoError(t, err) - require.Equal(t, &OptMessageType{MessageType: MessageTypeDiscover}, o) + require.Equal(t, MessageTypeDiscover, m) // Bad length data = []byte{1, 2} - _, err = ParseOptMessageType(data) + err = m.FromBytes(data) require.Error(t, err, "should get error from bad length") } -func TestOptMessageTypeString(t *testing.T) { - // known - o := OptMessageType{MessageType: MessageTypeDiscover} - require.Equal(t, "DHCP Message Type -> DISCOVER", o.String()) - - // unknown - o = OptMessageType{MessageType: 99} - require.Equal(t, "DHCP Message Type -> unknown (99)", o.String()) +func TestGetMessageType(t *testing.T) { + o := OptionsFromList(OptMessageType(MessageTypeDiscover)) + require.Equal(t, MessageTypeDiscover, GetMessageType(o)) + require.Equal(t, MessageTypeNone, GetMessageType(Options{})) } diff --git a/dhcpv4/option_parameter_request_list.go b/dhcpv4/option_parameter_request_list.go index 750c957..81d5efb 100644 --- a/dhcpv4/option_parameter_request_list.go +++ b/dhcpv4/option_parameter_request_list.go @@ -1,7 +1,7 @@ package dhcpv4 import ( - "fmt" + "sort" "strings" "github.com/u-root/u-root/pkg/uio" @@ -29,47 +29,59 @@ func (ol *OptionCodeList) Add(cs ...OptionCode) { } } +func (ol OptionCodeList) sort() { + sort.Slice(ol, func(i, j int) bool { return ol[i].Code() < ol[j].Code() }) +} + // String returns a human-readable string for the option names. func (ol OptionCodeList) String() string { var names []string + ol.sort() for _, code := range ol { names = append(names, code.String()) } return strings.Join(names, ", ") } -// OptParameterRequestList implements the parameter request list option -// described by RFC 2132, Section 9.8. -type OptParameterRequestList struct { - RequestedOpts OptionCodeList +// ToBytes returns a serialized stream of bytes for this option as defined by +// RFC 2132, Section 9.8. +func (ol OptionCodeList) ToBytes() []byte { + buf := uio.NewBigEndianBuffer(nil) + for _, req := range ol { + buf.Write8(req.Code()) + } + return buf.Data() } -// ParseOptParameterRequestList returns a new OptParameterRequestList from a -// byte stream, or error if any. -func ParseOptParameterRequestList(data []byte) (*OptParameterRequestList, error) { +// FromBytes parses a byte stream for this option as described by RFC 2132, +// Section 9.8. +func (ol *OptionCodeList) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) - requestedOpts := make(OptionCodeList, 0, buf.Len()) + *ol = make(OptionCodeList, 0, buf.Len()) for buf.Has(1) { - requestedOpts = append(requestedOpts, optionCode(buf.Read8())) + *ol = append(*ol, optionCode(buf.Read8())) } - return &OptParameterRequestList{RequestedOpts: requestedOpts}, buf.Error() + return buf.FinError() } -// Code returns the option code. -func (o *OptParameterRequestList) Code() OptionCode { - return OptionParameterRequestList -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptParameterRequestList) ToBytes() []byte { - buf := uio.NewBigEndianBuffer(nil) - for _, req := range o.RequestedOpts { - buf.Write8(req.Code()) +// GetParameterRequestList returns the DHCPv4 Parameter Request List in o. +// +// The parameter request list option is described by RFC 2132, Section 9.8. +func GetParameterRequestList(o Options) OptionCodeList { + v := o.Get(OptionParameterRequestList) + if v == nil { + return nil } - return buf.Data() + var codes OptionCodeList + if err := codes.FromBytes(v); err != nil { + return nil + } + return codes } -// String returns a human-readable string for this option. -func (o *OptParameterRequestList) String() string { - return fmt.Sprintf("Parameter Request List -> %s", o.RequestedOpts) +// OptParameterRequestList returns a new DHCPv4 Parameter Request List. +// +// The parameter request list option is described by RFC 2132, Section 9.8. +func OptParameterRequestList(codes ...OptionCode) Option { + return Option{Code: OptionParameterRequestList, Value: OptionCodeList(codes)} } diff --git a/dhcpv4/option_parameter_request_list_test.go b/dhcpv4/option_parameter_request_list_test.go index a09aaad..7c358e2 100644 --- a/dhcpv4/option_parameter_request_list_test.go +++ b/dhcpv4/option_parameter_request_list_test.go @@ -7,24 +7,22 @@ import ( ) func TestOptParameterRequestListInterfaceMethods(t *testing.T) { - requestedOpts := []OptionCode{OptionBootfileName, OptionNameServer} - o := &OptParameterRequestList{RequestedOpts: requestedOpts} - require.Equal(t, OptionParameterRequestList, o.Code(), "Code") + opts := []OptionCode{OptionBootfileName, OptionNameServer} + o := OptParameterRequestList(opts...) + + require.Equal(t, OptionParameterRequestList, o.Code, "Code") expectedBytes := []byte{67, 5} - require.Equal(t, expectedBytes, o.ToBytes(), "ToBytes") + require.Equal(t, expectedBytes, o.Value.ToBytes(), "ToBytes") - expectedString := "Parameter Request List -> Bootfile Name, Name Server" + expectedString := "Parameter Request List: Name Server, Bootfile Name" require.Equal(t, expectedString, o.String(), "String") } func TestParseOptParameterRequestList(t *testing.T) { - var ( - o *OptParameterRequestList - err error - ) - o, err = ParseOptParameterRequestList([]byte{67, 5}) + var o OptionCodeList + err := o.FromBytes([]byte{67, 5}) require.NoError(t, err) expectedOpts := OptionCodeList{OptionBootfileName, OptionNameServer} - require.Equal(t, expectedOpts, o.RequestedOpts) + require.Equal(t, expectedOpts, o) } diff --git a/dhcpv4/option_relay_agent_information.go b/dhcpv4/option_relay_agent_information.go index 42625ca..fb86c70 100644 --- a/dhcpv4/option_relay_agent_information.go +++ b/dhcpv4/option_relay_agent_information.go @@ -2,37 +2,52 @@ package dhcpv4 import ( "fmt" - - "github.com/u-root/u-root/pkg/uio" ) -// OptRelayAgentInformation implements the relay agent info option described by -// RFC 3046. -type OptRelayAgentInformation struct { - Options Options +// RelayOptions is like Options, but stringifies using the Relay Agent Specific +// option space. +type RelayOptions struct { + Options } -// ParseOptRelayAgentInformation returns a new OptRelayAgentInformation from a -// byte stream, or error if any. -func ParseOptRelayAgentInformation(data []byte) (*OptRelayAgentInformation, error) { - options, err := OptionsFromBytesWithParser(data, codeGetter, ParseOptionGeneric, false /* don't check for OptionEnd tag */) - if err != nil { - return nil, err - } - return &OptRelayAgentInformation{Options: options}, nil +var relayHumanizer = OptionHumanizer{ + ValueHumanizer: func(code OptionCode, data []byte) fmt.Stringer { + return OptionGeneric{data} + }, + CodeHumanizer: func(c uint8) OptionCode { + return GenericOptionCode(c) + }, +} + +// String prints the contained options using Relay Agent-specific option code parsing. +func (r RelayOptions) String() string { + return r.Options.ToString(relayHumanizer) } -// Code returns the option code. -func (o *OptRelayAgentInformation) Code() OptionCode { - return OptionRelayAgentInformation +// FromBytes parses relay agent options from data. +func (r *RelayOptions) FromBytes(data []byte) error { + r.Options = make(Options) + return r.Options.FromBytes(data) } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptRelayAgentInformation) ToBytes() []byte { - return uio.ToBigEndian(o.Options) +// OptRelayAgentInfo returns a new DHCP Relay Agent Info option. +// +// The relay agent info option is described by RFC 3046. +func OptRelayAgentInfo(o ...Option) Option { + return Option{Code: OptionRelayAgentInformation, Value: RelayOptions{OptionsFromList(o...)}} } -// String returns a human-readable string for this option. -func (o *OptRelayAgentInformation) String() string { - return fmt.Sprintf("Relay Agent Information -> %v", o.Options) +// GetRelayAgentInfo returns options embedded by the relay agent. +// +// The relay agent info option is described by RFC 3046. +func GetRelayAgentInfo(o Options) *RelayOptions { + v := o.Get(OptionRelayAgentInformation) + if v == nil { + return nil + } + var relayOptions RelayOptions + if err := relayOptions.FromBytes(v); err != nil { + return nil + } + return &relayOptions } diff --git a/dhcpv4/option_relay_agent_information_test.go b/dhcpv4/option_relay_agent_information_test.go index bb5fae0..6a7b275 100644 --- a/dhcpv4/option_relay_agent_information_test.go +++ b/dhcpv4/option_relay_agent_information_test.go @@ -6,47 +6,46 @@ import ( "github.com/stretchr/testify/require" ) -func TestParseOptRelayAgentInformation(t *testing.T) { - data := []byte{ - 1, 5, 'l', 'i', 'n', 'u', 'x', - 2, 4, 'b', 'o', 'o', 't', +func TestGetRelayAgentInformation(t *testing.T) { + o := Options{ + OptionRelayAgentInformation.Code(): []byte{ + 1, 5, 'l', 'i', 'n', 'u', 'x', + 2, 4, 'b', 'o', 'o', 't', + }, } - // short sub-option bytes - opt, err := ParseOptRelayAgentInformation([]byte{1, 0, 1}) - require.Error(t, err) + opt := GetRelayAgentInfo(o) + require.NotNil(t, opt) + require.Equal(t, len(opt.Options), 2) - // short sub-option length - opt, err = ParseOptRelayAgentInformation([]byte{1, 1}) - require.Error(t, err) + circuit := opt.Get(GenericOptionCode(1)) + remote := opt.Get(GenericOptionCode(2)) + require.Equal(t, circuit, []byte("linux")) + require.Equal(t, remote, []byte("boot")) - opt, err = ParseOptRelayAgentInformation(data) - require.NoError(t, err) - require.Equal(t, len(opt.Options), 2) - circuit := opt.Options.GetOne(optionCode(1)).(*OptionGeneric) - require.NoError(t, err) - remote := opt.Options.GetOne(optionCode(2)).(*OptionGeneric) - require.NoError(t, err) - require.Equal(t, circuit.Data, []byte("linux")) - require.Equal(t, remote.Data, []byte("boot")) -} + // Empty. + require.Nil(t, GetRelayAgentInfo(Options{})) -func TestParseOptRelayAgentInformationToBytes(t *testing.T) { - opt := OptRelayAgentInformation{ - Options: Options{ - &OptionGeneric{OptionCode: optionCode(1), Data: []byte("linux")}, - &OptionGeneric{OptionCode: optionCode(2), Data: []byte("boot")}, + // Invalid contents. + o = Options{ + OptionRelayAgentInformation.Code(): []byte{ + 1, 7, 'l', 'i', 'n', 'u', 'x', }, } - data := opt.ToBytes() - expected := []byte{ + require.Nil(t, GetRelayAgentInfo(o)) +} + +func TestOptRelayAgentInfo(t *testing.T) { + opt := OptRelayAgentInfo( + OptGeneric(GenericOptionCode(1), []byte("linux")), + OptGeneric(GenericOptionCode(2), []byte("boot")), + ) + wantBytes := []byte{ 1, 5, 'l', 'i', 'n', 'u', 'x', 2, 4, 'b', 'o', 'o', 't', } - require.Equal(t, expected, data) -} - -func TestOptRelayAgentInformationToBytesString(t *testing.T) { - o := OptRelayAgentInformation{} - require.Equal(t, "Relay Agent Information -> []", o.String()) + wantString := "Relay Agent Information:\n unknown (1): [108 105 110 117 120]\n unknown (2): [98 111 111 116]\n" + require.Equal(t, wantBytes, opt.Value.ToBytes()) + require.Equal(t, OptionRelayAgentInformation, opt.Code) + require.Equal(t, wantString, opt.String()) } diff --git a/dhcpv4/option_string.go b/dhcpv4/option_string.go index 76a2db2..9e16d6c 100644 --- a/dhcpv4/option_string.go +++ b/dhcpv4/option_string.go @@ -1,163 +1,117 @@ package dhcpv4 -import ( - "fmt" -) - -// OptDomainName implements the domain name option described in RFC 2132, -// Section 3.17. -type OptDomainName struct { - DomainName string -} - -// ParseOptDomainName returns a new OptDomainName from a byte stream, or error -// if any. -func ParseOptDomainName(data []byte) (*OptDomainName, error) { - return &OptDomainName{DomainName: string(data)}, nil -} - -// Code returns the option code. -func (o *OptDomainName) Code() OptionCode { - return OptionDomainName -} - -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptDomainName) ToBytes() []byte { - return []byte(o.DomainName) -} - -// String returns a human-readable string. -func (o *OptDomainName) String() string { - return fmt.Sprintf("Domain Name -> %v", o.DomainName) -} - -// OptHostName implements the host name option described by RFC 2132, Section -// 3.14. -type OptHostName struct { - HostName string -} - -// ParseOptHostName returns a new OptHostName from a byte stream, or error if -// any. -func ParseOptHostName(data []byte) (*OptHostName, error) { - return &OptHostName{HostName: string(data)}, nil -} - -// Code returns the option code. -func (o *OptHostName) Code() OptionCode { - return OptionHostName -} +// String represents an option encapsulating a string in IPv4 DHCP. +// +// This representation is shared by multiple options specified by RFC 2132, +// Sections 3.14, 3.16, 3.17, 3.19, and 3.20. +type String string // ToBytes returns a serialized stream of bytes for this option. -func (o *OptHostName) ToBytes() []byte { - return []byte(o.HostName) +func (o String) ToBytes() []byte { + return []byte(o) } // String returns a human-readable string. -func (o *OptHostName) String() string { - return fmt.Sprintf("Host Name -> %v", o.HostName) -} - -// OptRootPath implements the root path option described by RFC 2132, Section -// 3.19. -type OptRootPath struct { - Path string -} - -// ParseOptRootPath constructs an OptRootPath struct from a sequence of bytes -// and returns it, or an error. -func ParseOptRootPath(data []byte) (*OptRootPath, error) { - return &OptRootPath{Path: string(data)}, nil -} - -// Code returns the option code. -func (o *OptRootPath) Code() OptionCode { - return OptionRootPath +func (o String) String() string { + return string(o) } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptRootPath) ToBytes() []byte { - return []byte(o.Path) -} - -// String returns a human-readable string for this option. -func (o *OptRootPath) String() string { - return fmt.Sprintf("Root Path -> %v", o.Path) -} - -// OptBootfileName implements the bootfile name option described in RFC 2132, -// Section 9.5. -type OptBootfileName struct { - BootfileName string -} - -// Code returns the option code -func (op *OptBootfileName) Code() OptionCode { - return OptionBootfileName +// FromBytes parses a serialized stream of bytes into o. +func (o *String) FromBytes(data []byte) error { + *o = String(string(data)) + return nil } -// ToBytes serializes the option and returns it as a sequence of bytes -func (op *OptBootfileName) ToBytes() []byte { - return []byte(op.BootfileName) +// GetString parses an RFC 2132 string from o[code]. +func GetString(code OptionCode, o Options) string { + v := o.Get(code) + if v == nil { + return "" + } + return string(v) } -func (op *OptBootfileName) String() string { - return fmt.Sprintf("Bootfile Name -> %s", op.BootfileName) +// OptDomainName returns a new DHCPv4 Domain Name option. +// +// The Domain Name option is described by RFC 2132, Section 3.17. +func OptDomainName(name string) Option { + return Option{Code: OptionDomainName, Value: String(name)} } -// ParseOptBootfileName returns a new OptBootfile from a byte stream or error if any -func ParseOptBootfileName(data []byte) (*OptBootfileName, error) { - return &OptBootfileName{BootfileName: string(data)}, nil +// GetDomainName parses the DHCPv4 Domain Name option from o if present. +// +// The Domain Name option is described by RFC 2132, Section 3.17. +func GetDomainName(o Options) string { + return GetString(OptionDomainName, o) } -// OptTFTPServerName implements the TFTP server name option described by RFC -// 2132, Section 9.4. -type OptTFTPServerName struct { - TFTPServerName string +// OptHostName returns a new DHCPv4 Host Name option. +// +// The Host Name option is described by RFC 2132, Section 3.14. +func OptHostName(name string) Option { + return Option{Code: OptionHostName, Value: String(name)} } -// Code returns the option code -func (op *OptTFTPServerName) Code() OptionCode { - return OptionTFTPServerName +// GetHostName parses the DHCPv4 Host Name option from o if present. +// +// The Host Name option is described by RFC 2132, Section 3.14. +func GetHostName(o Options) string { + return GetString(OptionHostName, o) } -// ToBytes serializes the option and returns it as a sequence of bytes -func (op *OptTFTPServerName) ToBytes() []byte { - return []byte(op.TFTPServerName) +// OptRootPath returns a new DHCPv4 Root Path option. +// +// The Root Path option is described by RFC 2132, Section 3.19. +func OptRootPath(name string) Option { + return Option{Code: OptionRootPath, Value: String(name)} } -func (op *OptTFTPServerName) String() string { - return fmt.Sprintf("TFTP Server Name -> %s", op.TFTPServerName) +// GetRootPath parses the DHCPv4 Root Path option from o if present. +// +// The Root Path option is described by RFC 2132, Section 3.19. +func GetRootPath(o Options) string { + return GetString(OptionRootPath, o) } -// ParseOptTFTPServerName returns a new OptTFTPServerName from a byte stream or error if any -func ParseOptTFTPServerName(data []byte) (*OptTFTPServerName, error) { - return &OptTFTPServerName{TFTPServerName: string(data)}, nil +// OptBootFileName returns a new DHCPv4 Boot File Name option. +// +// The Bootfile Name option is described by RFC 2132, Section 9.5. +func OptBootFileName(name string) Option { + return Option{Code: OptionBootfileName, Value: String(name)} } -// OptClassIdentifier implements the vendor class identifier option described -// in RFC 2132, Section 9.13. -type OptClassIdentifier struct { - Identifier string +// GetBootFileName parses the DHCPv4 Bootfile Name option from o if present. +// +// The Bootfile Name option is described by RFC 2132, Section 9.5. +func GetBootFileName(o Options) string { + return GetString(OptionBootfileName, o) } -// ParseOptClassIdentifier constructs an OptClassIdentifier struct from a sequence of -// bytes and returns it, or an error. -func ParseOptClassIdentifier(data []byte) (*OptClassIdentifier, error) { - return &OptClassIdentifier{Identifier: string(data)}, nil +// OptTFTPServerName returns a new DHCPv4 TFTP Server Name option. +// +// The TFTP Server Name option is described by RFC 2132, Section 9.4. +func OptTFTPServerName(name string) Option { + return Option{Code: OptionTFTPServerName, Value: String(name)} } -// Code returns the option code. -func (o *OptClassIdentifier) Code() OptionCode { - return OptionClassIdentifier +// GetTFTPServerName parses the DHCPv4 TFTP Server Name option from o if +// present. +// +// The TFTP Server Name option is described by RFC 2132, Section 9.4. +func GetTFTPServerName(o Options) string { + return GetString(OptionTFTPServerName, o) } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptClassIdentifier) ToBytes() []byte { - return []byte(o.Identifier) +// OptClassIdentifier returns a new DHCPv4 Class Identifier option. +// +// The Vendor Class Identifier option is described by RFC 2132, Section 9.13. +func OptClassIdentifier(name string) Option { + return Option{Code: OptionClassIdentifier, Value: String(name)} } -// String returns a human-readable string for this option. -func (o *OptClassIdentifier) String() string { - return fmt.Sprintf("Class Identifier -> %v", o.Identifier) +// GetClassIdentifier parses the DHCPv4 Class Identifier option from o if present. +// +// The Vendor Class Identifier option is described by RFC 2132, Section 9.13. +func GetClassIdentifier(o Options) string { + return GetString(OptionClassIdentifier, o) } diff --git a/dhcpv4/option_string_test.go b/dhcpv4/option_string_test.go index 0704f31..bda6009 100644 --- a/dhcpv4/option_string_test.go +++ b/dhcpv4/option_string_test.go @@ -7,96 +7,83 @@ import ( ) func TestOptDomainName(t *testing.T) { - o := OptDomainName{DomainName: "foo"} - require.Equal(t, OptionDomainName, o.Code(), "Code") - require.Equal(t, []byte{'f', 'o', 'o'}, o.ToBytes(), "ToBytes") - require.Equal(t, "Domain Name -> foo", o.String()) + o := OptDomainName("foo") + require.Equal(t, OptionDomainName, o.Code, "Code") + require.Equal(t, []byte{'f', 'o', 'o'}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Domain Name: foo", o.String()) } func TestParseOptDomainName(t *testing.T) { - data := []byte{'t', 'e', 's', 't'} - o, err := ParseOptDomainName(data) - require.NoError(t, err) - require.Equal(t, &OptDomainName{DomainName: "test"}, o) + o := Options{ + OptionDomainName.Code(): []byte{'t', 'e', 's', 't'}, + } + require.Equal(t, "test", GetDomainName(o)) + require.Equal(t, "", GetDomainName(Options{})) } func TestOptHostName(t *testing.T) { - o := OptHostName{HostName: "foo"} - require.Equal(t, OptionHostName, o.Code(), "Code") - require.Equal(t, []byte{'f', 'o', 'o'}, o.ToBytes(), "ToBytes") - require.Equal(t, "Host Name -> foo", o.String()) + o := OptHostName("foo") + require.Equal(t, OptionHostName, o.Code, "Code") + require.Equal(t, []byte{'f', 'o', 'o'}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Host Name: foo", o.String()) } func TestParseOptHostName(t *testing.T) { - data := []byte{'t', 'e', 's', 't'} - o, err := ParseOptHostName(data) - require.NoError(t, err) - require.Equal(t, &OptHostName{HostName: "test"}, o) + o := Options{ + OptionHostName.Code(): []byte{'t', 'e', 's', 't'}, + } + require.Equal(t, "test", GetHostName(o)) + require.Equal(t, "", GetHostName(Options{})) } func TestOptRootPath(t *testing.T) { - o := OptRootPath{Path: "/foo/bar/baz"} - require.Equal(t, OptionRootPath, o.Code(), "Code") - wantBytes := []byte{ - '/', 'f', 'o', 'o', '/', 'b', 'a', 'r', '/', 'b', 'a', 'z', - } - require.Equal(t, wantBytes, o.ToBytes(), "ToBytes") - require.Equal(t, "Root Path -> /foo/bar/baz", o.String()) + o := OptRootPath("foo") + require.Equal(t, OptionRootPath, o.Code, "Code") + require.Equal(t, []byte{'f', 'o', 'o'}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Root Path: foo", o.String()) } func TestParseOptRootPath(t *testing.T) { - data := []byte{byte(OptionRootPath), 4, '/', 'f', 'o', 'o'} - o, err := ParseOptRootPath(data[2:]) - require.NoError(t, err) - require.Equal(t, &OptRootPath{Path: "/foo"}, o) + o := OptionsFromList(OptRootPath("test")) + require.Equal(t, "test", GetRootPath(o)) + require.Equal(t, "", GetRootPath(Options{})) } -func TestOptBootfileName(t *testing.T) { - opt := OptBootfileName{ - BootfileName: "linuxboot", - } - require.Equal(t, OptionBootfileName, opt.Code()) - require.Equal(t, []byte{'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't'}, opt.ToBytes()) - require.Equal(t, "Bootfile Name -> linuxboot", opt.String()) +func TestOptBootFileName(t *testing.T) { + o := OptBootFileName("foo") + require.Equal(t, OptionBootfileName, o.Code, "Code") + require.Equal(t, []byte{'f', 'o', 'o'}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Bootfile Name: foo", o.String()) } -func TestParseOptBootfileName(t *testing.T) { - expected := []byte{ - 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', - } - opt, err := ParseOptBootfileName(expected) - require.NoError(t, err) - require.Equal(t, "linuxboot", opt.BootfileName) +func TestParseOptBootFileName(t *testing.T) { + o := OptionsFromList(OptBootFileName("test")) + require.Equal(t, "test", GetBootFileName(o)) + require.Equal(t, "", GetBootFileName(Options{})) } -func TestOptTFTPServer(t *testing.T) { - opt := OptTFTPServerName{ - TFTPServerName: "linuxboot", - } - require.Equal(t, OptionTFTPServerName, opt.Code()) - require.Equal(t, []byte("linuxboot"), opt.ToBytes()) - require.Equal(t, "TFTP Server Name -> linuxboot", opt.String()) +func TestOptTFTPServerName(t *testing.T) { + o := OptTFTPServerName("foo") + require.Equal(t, OptionTFTPServerName, o.Code, "Code") + require.Equal(t, []byte{'f', 'o', 'o'}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "TFTP Server Name: foo", o.String()) } func TestParseOptTFTPServerName(t *testing.T) { - expected := []byte{ - 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', - } - opt, err := ParseOptTFTPServerName(expected) - require.NoError(t, err) - require.Equal(t, "linuxboot", string(opt.TFTPServerName)) + o := OptionsFromList(OptTFTPServerName("test")) + require.Equal(t, "test", GetTFTPServerName(o)) + require.Equal(t, "", GetTFTPServerName(Options{})) } func TestOptClassIdentifier(t *testing.T) { - o := OptClassIdentifier{Identifier: "foo"} - require.Equal(t, OptionClassIdentifier, o.Code(), "Code") - require.Equal(t, []byte("foo"), o.ToBytes(), "ToBytes") - require.Equal(t, "Class Identifier -> foo", o.String()) + o := OptClassIdentifier("foo") + require.Equal(t, OptionClassIdentifier, o.Code, "Code") + require.Equal(t, []byte{'f', 'o', 'o'}, o.Value.ToBytes(), "ToBytes") + require.Equal(t, "Class Identifier: foo", o.String()) } func TestParseOptClassIdentifier(t *testing.T) { - data := []byte("test") - o, err := ParseOptClassIdentifier(data) - require.NoError(t, err) - require.Equal(t, &OptClassIdentifier{Identifier: "test"}, o) + o := OptionsFromList(OptClassIdentifier("test")) + require.Equal(t, "test", GetClassIdentifier(o)) + require.Equal(t, "", GetClassIdentifier(Options{})) } diff --git a/dhcpv4/option_subnet_mask.go b/dhcpv4/option_subnet_mask.go index 19401d8..82b344b 100644 --- a/dhcpv4/option_subnet_mask.go +++ b/dhcpv4/option_subnet_mask.go @@ -1,36 +1,55 @@ package dhcpv4 import ( - "fmt" "net" "github.com/u-root/u-root/pkg/uio" ) -// OptSubnetMask implements the subnet mask option described by RFC 2132, -// Section 3.3. -type OptSubnetMask struct { - SubnetMask net.IPMask +// IPMask represents an option encapsulating the subnet mask. +// +// This option implements the subnet mask option in RFC 2132, Section 3.3. +type IPMask net.IPMask + +// ToBytes returns a serialized stream of bytes for this option. +func (im IPMask) ToBytes() []byte { + if len(im) > net.IPv4len { + return im[:net.IPv4len] + } + return im } -// ParseOptSubnetMask returns a new OptSubnetMask from a byte -// stream, or error if any. -func ParseOptSubnetMask(data []byte) (*OptSubnetMask, error) { - buf := uio.NewBigEndianBuffer(data) - return &OptSubnetMask{SubnetMask: net.IPMask(buf.CopyN(net.IPv4len))}, buf.FinError() +// String returns a human-readable string. +func (im IPMask) String() string { + return net.IPMask(im).String() } -// Code returns the option code. -func (o *OptSubnetMask) Code() OptionCode { - return OptionSubnetMask +// FromBytes parses im from data per RFC 2132. +func (im *IPMask) FromBytes(data []byte) error { + buf := uio.NewBigEndianBuffer(data) + *im = IPMask(buf.CopyN(net.IPv4len)) + return buf.FinError() } -// ToBytes returns a serialized stream of bytes for this option. -func (o *OptSubnetMask) ToBytes() []byte { - return o.SubnetMask[:net.IPv4len] +// GetSubnetMask returns a subnet mask option contained in o, if there is one. +// +// The subnet mask option is described by RFC 2132, Section 3.3. +func GetSubnetMask(o Options) net.IPMask { + v := o.Get(OptionSubnetMask) + if v == nil { + return nil + } + var im IPMask + if err := im.FromBytes(v); err != nil { + return nil + } + return net.IPMask(im) } -// String returns a human-readable string. -func (o *OptSubnetMask) String() string { - return fmt.Sprintf("Subnet Mask -> %v", o.SubnetMask.String()) +// OptSubnetMask returns a new DHCPv4 SubnetMask option per RFC 2132, Section 3.3. +func OptSubnetMask(mask net.IPMask) Option { + return Option{ + Code: OptionSubnetMask, + Value: IPMask(mask), + } } diff --git a/dhcpv4/option_subnet_mask_test.go b/dhcpv4/option_subnet_mask_test.go index f04a481..bc82cf1 100644 --- a/dhcpv4/option_subnet_mask_test.go +++ b/dhcpv4/option_subnet_mask_test.go @@ -7,30 +7,23 @@ import ( "github.com/stretchr/testify/require" ) -func TestOptSubnetMaskInterfaceMethods(t *testing.T) { - mask := net.IPMask{255, 255, 255, 0} - o := OptSubnetMask{SubnetMask: mask} - - require.Equal(t, OptionSubnetMask, o.Code(), "Code") - - expectedBytes := []byte{255, 255, 255, 0} - require.Equal(t, expectedBytes, o.ToBytes(), "ToBytes") - - require.Equal(t, "Subnet Mask -> ffffff00", o.String(), "String") +func TestOptSubnetMask(t *testing.T) { + o := OptSubnetMask(net.IPMask{255, 255, 255, 0}) + require.Equal(t, o.Code, OptionSubnetMask, "Code") + require.Equal(t, "Subnet Mask: ffffff00", o.String(), "String") + require.Equal(t, []byte{255, 255, 255, 0}, o.Value.ToBytes(), "ToBytes") } -func TestParseOptSubnetMask(t *testing.T) { - var ( - o *OptSubnetMask - err error - ) - o, err = ParseOptSubnetMask([]byte{}) - require.Error(t, err, "empty byte stream") +func TestGetSubnetMask(t *testing.T) { + o := OptionsFromList(OptSubnetMask(net.IPMask{})) + mask := GetSubnetMask(o) + require.Nil(t, mask, "empty byte stream") - o, err = ParseOptSubnetMask([]byte{255}) - require.Error(t, err, "short byte stream") + o = OptionsFromList(OptSubnetMask(net.IPMask{255})) + mask = GetSubnetMask(o) + require.Nil(t, mask, "short byte stream") - o, err = ParseOptSubnetMask([]byte{255, 255, 255, 0}) - require.NoError(t, err) - require.Equal(t, net.IPMask{255, 255, 255, 0}, o.SubnetMask) + o = OptionsFromList(OptSubnetMask(net.IPMask{255, 255, 255, 0})) + mask = GetSubnetMask(o) + require.Equal(t, net.IPMask{255, 255, 255, 0}, mask) } diff --git a/dhcpv4/option_userclass.go b/dhcpv4/option_userclass.go index 110cb37..f273a84 100644 --- a/dhcpv4/option_userclass.go +++ b/dhcpv4/option_userclass.go @@ -8,21 +8,53 @@ import ( "github.com/u-root/u-root/pkg/uio" ) -// OptUserClass implements the user class option described by RFC 3004. -type OptUserClass struct { +// UserClass implements the user class option described by RFC 3004. +type UserClass struct { UserClasses [][]byte - Rfc3004 bool + RFC3004 bool } -// Code returns the option code -func (op *OptUserClass) Code() OptionCode { - return OptionUserClassInformation +// GetUserClass returns the user class in o if present. +// +// The user class information option is defined by RFC 3004. +func GetUserClass(o Options) *UserClass { + v := o.Get(OptionUserClassInformation) + if v == nil { + return nil + } + var uc UserClass + if err := uc.FromBytes(v); err != nil { + return nil + } + return &uc +} + +// OptUserClass returns a new user class option. +func OptUserClass(v []byte) Option { + return Option{ + Code: OptionUserClassInformation, + Value: &UserClass{ + UserClasses: [][]byte{v}, + RFC3004: false, + }, + } +} + +// OptRFC3004UserClass returns a new user class option according to RFC 3004. +func OptRFC3004UserClass(v [][]byte) Option { + return Option{ + Code: OptionUserClassInformation, + Value: &UserClass{ + UserClasses: v, + RFC3004: true, + }, + } } // ToBytes serializes the option and returns it as a sequence of bytes -func (op *OptUserClass) ToBytes() []byte { +func (op *UserClass) ToBytes() []byte { buf := uio.NewBigEndianBuffer(nil) - if !op.Rfc3004 { + if !op.RFC3004 { buf.WriteBytes(op.UserClasses[0]) } else { for _, uc := range op.UserClasses { @@ -33,22 +65,21 @@ func (op *OptUserClass) ToBytes() []byte { return buf.Data() } -func (op *OptUserClass) String() string { +// String returns a human-readable user class. +func (op *UserClass) String() string { ucStrings := make([]string, 0, len(op.UserClasses)) - if !op.Rfc3004 { + if !op.RFC3004 { ucStrings = append(ucStrings, string(op.UserClasses[0])) } else { for _, uc := range op.UserClasses { ucStrings = append(ucStrings, string(uc)) } } - return fmt.Sprintf("User Class Information -> %v", strings.Join(ucStrings, ", ")) + return strings.Join(ucStrings, ", ") } -// ParseOptUserClass returns a new OptUserClass from a byte stream or -// error if any -func ParseOptUserClass(data []byte) (*OptUserClass, error) { - opt := OptUserClass{} +// FromBytes parses data into op. +func (op *UserClass) FromBytes(data []byte) error { buf := uio.NewBigEndianBuffer(data) // Check if option is Microsoft style instead of RFC compliant, issue #113 @@ -64,19 +95,19 @@ func ParseOptUserClass(data []byte) (*OptUserClass, error) { counting += int(data[counting]) + 1 } if counting != buf.Len() { - opt.UserClasses = append(opt.UserClasses, data) - return &opt, nil + op.UserClasses = append(op.UserClasses, data) + return nil } - opt.Rfc3004 = true + op.RFC3004 = true for buf.Has(1) { ucLen := buf.Read8() if ucLen == 0 { - return nil, fmt.Errorf("DHCP user class must have length greater than 0") + return fmt.Errorf("DHCP user class must have length greater than 0") } - opt.UserClasses = append(opt.UserClasses, buf.CopyN(int(ucLen))) + op.UserClasses = append(op.UserClasses, buf.CopyN(int(ucLen))) } - if len(opt.UserClasses) == 0 { - return nil, errors.New("ParseOptUserClass: at least one user class is required") + if len(op.UserClasses) == 0 { + return errors.New("ParseOptUserClass: at least one user class is required") } - return &opt, buf.FinError() + return buf.FinError() } diff --git a/dhcpv4/option_userclass_test.go b/dhcpv4/option_userclass_test.go index e321a64..149fb92 100644 --- a/dhcpv4/option_userclass_test.go +++ b/dhcpv4/option_userclass_test.go @@ -7,11 +7,8 @@ import ( ) func TestOptUserClassToBytes(t *testing.T) { - opt := OptUserClass{ - UserClasses: [][]byte{[]byte("linuxboot")}, - Rfc3004: true, - } - data := opt.ToBytes() + opt := OptRFC3004UserClass([][]byte{[]byte("linuxboot")}) + data := opt.Value.ToBytes() expected := []byte{ 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } @@ -19,10 +16,8 @@ func TestOptUserClassToBytes(t *testing.T) { } func TestOptUserClassMicrosoftToBytes(t *testing.T) { - opt := OptUserClass{ - UserClasses: [][]byte{[]byte("linuxboot")}, - } - data := opt.ToBytes() + opt := OptUserClass([]byte("linuxboot")) + data := opt.Value.ToBytes() expected := []byte{ 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } @@ -30,11 +25,12 @@ func TestOptUserClassMicrosoftToBytes(t *testing.T) { } func TestParseOptUserClassMultiple(t *testing.T) { + var opt UserClass expected := []byte{ 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 4, 't', 'e', 's', 't', } - opt, err := ParseOptUserClass(expected) + err := opt.FromBytes(expected) require.NoError(t, err) require.Equal(t, len(opt.UserClasses), 2) require.Equal(t, []byte("linuxboot"), opt.UserClasses[0]) @@ -42,50 +38,53 @@ func TestParseOptUserClassMultiple(t *testing.T) { } func TestParseOptUserClassNone(t *testing.T) { + var opt UserClass expected := []byte{} - _, err := ParseOptUserClass(expected) + err := opt.FromBytes(expected) require.Error(t, err) } func TestParseOptUserClassMicrosoft(t *testing.T) { + var opt UserClass expected := []byte{ 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } - opt, err := ParseOptUserClass(expected) + err := opt.FromBytes(expected) require.NoError(t, err) require.Equal(t, 1, len(opt.UserClasses)) require.Equal(t, []byte("linuxboot"), opt.UserClasses[0]) } func TestParseOptUserClassMicrosoftShort(t *testing.T) { + var opt UserClass expected := []byte{ 'l', } - opt, err := ParseOptUserClass(expected) + err := opt.FromBytes(expected) require.NoError(t, err) require.Equal(t, 1, len(opt.UserClasses)) require.Equal(t, []byte("l"), opt.UserClasses[0]) } func TestParseOptUserClass(t *testing.T) { + var opt UserClass expected := []byte{ 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } - opt, err := ParseOptUserClass(expected) + err := opt.FromBytes(expected) require.NoError(t, err) require.Equal(t, 1, len(opt.UserClasses)) require.Equal(t, []byte("linuxboot"), opt.UserClasses[0]) } func TestOptUserClassToBytesMultiple(t *testing.T) { - opt := OptUserClass{ - UserClasses: [][]byte{ + opt := OptRFC3004UserClass( + [][]byte{ []byte("linuxboot"), []byte("test"), }, - Rfc3004: true, - } - data := opt.ToBytes() + ) + data := opt.Value.ToBytes() expected := []byte{ 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', 4, 't', 'e', 's', 't', @@ -94,14 +93,7 @@ func TestOptUserClassToBytesMultiple(t *testing.T) { } func TestParseOptUserClassZeroLength(t *testing.T) { - expected := []byte{ - 0, 0, - } - _, err := ParseOptUserClass(expected) + var opt UserClass + err := opt.FromBytes([]byte{0, 0}) require.Error(t, err) } - -func TestOptUserClassCode(t *testing.T) { - opt := OptUserClass{} - require.Equal(t, OptionUserClassInformation, opt.Code()) -} diff --git a/dhcpv4/option_vivc.go b/dhcpv4/option_vivc.go index b6efab9..509ba80 100644 --- a/dhcpv4/option_vivc.go +++ b/dhcpv4/option_vivc.go @@ -10,39 +10,53 @@ import ( // VIVCIdentifier implements the vendor-identifying vendor class option // described by RFC 3925. type VIVCIdentifier struct { + // EntID is the enterprise ID. EntID uint32 Data []byte } -// OptVIVC represents the DHCP message type option. -type OptVIVC struct { - Identifiers []VIVCIdentifier +// OptVIVC returns a new vendor-identifying vendor class option. +// +// The option is described by RFC 3925. +func OptVIVC(identifiers ...VIVCIdentifier) Option { + return Option{ + Code: OptionVendorIdentifyingVendorClass, + Value: VIVCIdentifiers(identifiers), + } } -// ParseOptVIVC contructs an OptVIVC tsruct from a sequence of bytes and returns -// it, or an error. -func ParseOptVIVC(data []byte) (*OptVIVC, error) { - buf := uio.NewBigEndianBuffer(data) +// GetVIVC returns the vendor-identifying vendor class option in o if present. +func GetVIVC(o Options) VIVCIdentifiers { + v := o.Get(OptionVendorIdentifyingVendorClass) + if v == nil { + return nil + } + var ids VIVCIdentifiers + if err := ids.FromBytes(v); err != nil { + return nil + } + return ids +} + +// VIVCIdentifiers implements encoding and decoding methods for a DHCP option +// described in RFC 3925. +type VIVCIdentifiers []VIVCIdentifier - var ids []VIVCIdentifier +// FromBytes parses data into ids per RFC 3925. +func (ids *VIVCIdentifiers) FromBytes(data []byte) error { + buf := uio.NewBigEndianBuffer(data) for buf.Has(5) { entID := buf.Read32() idLen := int(buf.Read8()) - ids = append(ids, VIVCIdentifier{EntID: entID, Data: buf.CopyN(idLen)}) + *ids = append(*ids, VIVCIdentifier{EntID: entID, Data: buf.CopyN(idLen)}) } - - return &OptVIVC{Identifiers: ids}, buf.FinError() -} - -// Code returns the option code. -func (o *OptVIVC) Code() OptionCode { - return OptionVendorIdentifyingVendorClass + return buf.FinError() } // ToBytes returns a serialized stream of bytes for this option. -func (o *OptVIVC) ToBytes() []byte { +func (ids VIVCIdentifiers) ToBytes() []byte { buf := uio.NewBigEndianBuffer(nil) - for _, id := range o.Identifiers { + for _, id := range ids { buf.Write32(id.EntID) buf.Write8(uint8(len(id.Data))) buf.WriteBytes(id.Data) @@ -51,13 +65,13 @@ func (o *OptVIVC) ToBytes() []byte { } // String returns a human-readable string for this option. -func (o *OptVIVC) String() string { +func (ids VIVCIdentifiers) String() string { + if len(ids) == 0 { + return "" + } buf := bytes.Buffer{} - fmt.Fprintf(&buf, "Vendor-Identifying Vendor Class ->") - - for _, id := range o.Identifiers { + for _, id := range ids { fmt.Fprintf(&buf, " %d:'%s',", id.EntID, id.Data) } - - return buf.String()[:buf.Len()-1] + return buf.String()[1 : buf.Len()-1] } diff --git a/dhcpv4/option_vivc_test.go b/dhcpv4/option_vivc_test.go index 9b3b704..b1ec398 100644 --- a/dhcpv4/option_vivc_test.go +++ b/dhcpv4/option_vivc_test.go @@ -7,11 +7,9 @@ import ( ) var ( - sampleVIVCOpt = OptVIVC{ - Identifiers: []VIVCIdentifier{ - {EntID: 9, Data: []byte("CiscoIdentifier")}, - {EntID: 18, Data: []byte("WellfleetIdentifier")}, - }, + sampleVIVCOpt = VIVCIdentifiers{ + VIVCIdentifier{EntID: 9, Data: []byte("CiscoIdentifier")}, + VIVCIdentifier{EntID: 18, Data: []byte("WellfleetIdentifier")}, } sampleVIVCOptRaw = []byte{ 0x0, 0x0, 0x0, 0x9, // enterprise id 9 @@ -24,30 +22,31 @@ var ( ) func TestOptVIVCInterfaceMethods(t *testing.T) { - require.Equal(t, OptionVendorIdentifyingVendorClass, sampleVIVCOpt.Code(), "Code") - require.Equal(t, sampleVIVCOptRaw, sampleVIVCOpt.ToBytes(), "ToBytes") + opt := OptVIVC(sampleVIVCOpt...) + require.Equal(t, OptionVendorIdentifyingVendorClass, opt.Code, "Code") + require.Equal(t, sampleVIVCOptRaw, opt.Value.ToBytes(), "ToBytes") + require.Equal(t, "Vendor-Identifying Vendor Class: 9:'CiscoIdentifier', 18:'WellfleetIdentifier'", + opt.String()) } func TestParseOptVICO(t *testing.T) { - o, err := ParseOptVIVC(sampleVIVCOptRaw) - require.NoError(t, err) - require.Equal(t, &sampleVIVCOpt, o) + options := Options{OptionVendorIdentifyingVendorClass.Code(): sampleVIVCOptRaw} + o := GetVIVC(options) + require.Equal(t, sampleVIVCOpt, o) // Identifier len too long data := make([]byte, len(sampleVIVCOptRaw)) copy(data, sampleVIVCOptRaw) data[4] = 40 - _, err = ParseOptVIVC(data) - require.Error(t, err, "should get error from bad length") + options = Options{OptionVendorIdentifyingVendorClass.Code(): data} + o = GetVIVC(options) + require.Nil(t, o, "should get error from bad length") // Longer than length data[4] = 5 - o, err = ParseOptVIVC(data[:10]) - require.NoError(t, err) - require.Equal(t, o.Identifiers[0].Data, []byte("Cisco")) -} + options = Options{OptionVendorIdentifyingVendorClass.Code(): data[:10]} + o = GetVIVC(options) + require.Equal(t, o[0].Data, []byte("Cisco")) -func TestOptVIVCString(t *testing.T) { - require.Equal(t, "Vendor-Identifying Vendor Class -> 9:'CiscoIdentifier', 18:'WellfleetIdentifier'", - sampleVIVCOpt.String()) + require.Equal(t, VIVCIdentifiers(nil), GetVIVC(Options{})) } diff --git a/dhcpv4/options.go b/dhcpv4/options.go index 3d774d1..4c70743 100644 --- a/dhcpv4/options.go +++ b/dhcpv4/options.go @@ -5,7 +5,11 @@ import ( "fmt" "io" "math" + "sort" + "strings" + "github.com/insomniacslk/dhcp/iana" + "github.com/insomniacslk/dhcp/rfc1035label" "github.com/u-root/u-root/pkg/uio" ) @@ -24,149 +28,95 @@ var ( ErrInvalidOptions = errors.New("invalid options data") ) -// Option is an interface that all DHCP v4 options adhere to. -type Option interface { - Code() OptionCode +// OptionValue is an interface that all DHCP v4 options adhere to. +type OptionValue interface { ToBytes() []byte String() string } -// ParseOption parses a sequence of bytes as a single DHCPv4 option, returning -// the specific option structure or error, if any. -func ParseOption(code OptionCode, data []byte) (Option, error) { - var ( - opt Option - err error - ) - switch code { - case OptionSubnetMask: - opt, err = ParseOptSubnetMask(data) - case OptionRouter: - opt, err = ParseOptRouter(data) - case OptionDomainNameServer: - opt, err = ParseOptDomainNameServer(data) - case OptionHostName: - opt, err = ParseOptHostName(data) - case OptionDomainName: - opt, err = ParseOptDomainName(data) - case OptionRootPath: - opt, err = ParseOptRootPath(data) - case OptionBroadcastAddress: - opt, err = ParseOptBroadcastAddress(data) - case OptionNTPServers: - opt, err = ParseOptNTPServers(data) - case OptionRequestedIPAddress: - opt, err = ParseOptRequestedIPAddress(data) - case OptionIPAddressLeaseTime: - opt, err = ParseOptIPAddressLeaseTime(data) - case OptionDHCPMessageType: - opt, err = ParseOptMessageType(data) - case OptionServerIdentifier: - opt, err = ParseOptServerIdentifier(data) - case OptionParameterRequestList: - opt, err = ParseOptParameterRequestList(data) - case OptionMaximumDHCPMessageSize: - opt, err = ParseOptMaximumDHCPMessageSize(data) - case OptionClassIdentifier: - opt, err = ParseOptClassIdentifier(data) - case OptionTFTPServerName: - opt, err = ParseOptTFTPServerName(data) - case OptionBootfileName: - opt, err = ParseOptBootfileName(data) - case OptionUserClassInformation: - opt, err = ParseOptUserClass(data) - case OptionRelayAgentInformation: - opt, err = ParseOptRelayAgentInformation(data) - case OptionClientSystemArchitectureType: - opt, err = ParseOptClientArchType(data) - case OptionDNSDomainSearchList: - opt, err = ParseOptDomainSearch(data) - case OptionVendorIdentifyingVendorClass: - opt, err = ParseOptVIVC(data) - default: - opt, err = ParseOptionGeneric(code, data) - } - if err != nil { - return nil, err +// Option is a DHCPv4 option and consists of a 1-byte option code and a value +// stream of bytes. +// +// The value is to be interpreted based on the option code. +type Option struct { + Code OptionCode + Value OptionValue +} + +// String returns a human-readable version of this option. +func (o Option) String() string { + v := o.Value.String() + if strings.Contains(v, "\n") { + return fmt.Sprintf("%s:\n%s", o.Code, v) } - return opt, nil + return fmt.Sprintf("%s: %s", o.Code, v) } // Options is a collection of options. -type Options []Option +type Options map[uint8][]byte -// GetOne will attempt to get an option that match a Option code. If there -// are multiple options with the same OptionCode it will only return the first -// one found. If no matching option is found nil will be returned. -func (o Options) GetOne(code OptionCode) Option { +// OptionsFromList adds all given options to an options map. +func OptionsFromList(o ...Option) Options { + opts := make(Options) for _, opt := range o { - if opt.Code() == code { - return opt - } + opts.Update(opt) } - return nil + return opts } -// Has checks whether o has the given `opcode` Option. -func (o Options) Has(code OptionCode) bool { - return o.GetOne(code) != nil +// Get will attempt to get all options that match a DHCPv4 option +// from its OptionCode. If the option was not found it will return an +// empty list. +// +// According to RFC 3396, options that are specified more than once are +// concatenated, and hence this should always just return one option. This +// currently returns a list to be API compatible. +func (o Options) Get(code OptionCode) []byte { + return o[code.Code()] } -// Update replaces an existing option with the same option code with the given -// one, adding it if not already present. -// -// Per RFC 2131, Section 4.1, "options may appear only once." -// -// An End option is ignored. -func (o *Options) Update(option Option) { - if option.Code() == OptionEnd { - return - } +// Has checks whether o has the given opcode. +func (o Options) Has(opcode OptionCode) bool { + _, ok := o[opcode.Code()] + return ok +} - for idx, opt := range *o { - if opt.Code() == option.Code() { - (*o)[idx] = option - // Don't look further. - return - } - } - // If not found, add it. - *o = append(*o, option) +// Update updates the existing options with the passed option, adding it +// at the end if not present already +func (o Options) Update(option Option) { + o[option.Code.Code()] = option.Value.ToBytes() +} + +// ToBytes makes Options usable as an OptionValue as well. +// +// Used in the case of vendor-specific and relay agent options. +func (o Options) ToBytes() []byte { + return uio.ToBigEndian(o) } -// OptionsFromBytes parses a sequence of bytes until the end and builds a list -// of options from it. +// FromBytes parses a sequence of bytes until the end and builds a list of +// options from it. // // The sequence should not contain the DHCP magic cookie. // // Returns an error if any invalid option or length is found. -func OptionsFromBytes(data []byte) (Options, error) { - return OptionsFromBytesWithParser(data, codeGetter, ParseOption, true) +func (o Options) FromBytes(data []byte) error { + return o.fromBytesCheckEnd(data, false) } -// OptionParser is a function signature for option parsing. -type OptionParser func(code OptionCode, data []byte) (Option, error) - -// OptionCodeGetter parses a code into an OptionCode. -type OptionCodeGetter func(code uint8) OptionCode - -// codeGetter is an OptionCodeGetter for DHCP optionCodes. -func codeGetter(c uint8) OptionCode { - return optionCode(c) -} +const ( + optPad = 0 + optEnd = 255 +) -// OptionsFromBytesWithParser parses Options from byte sequences using the +// FromBytesCheckEnd parses Options from byte sequences using the // parsing function that is passed in as a paremeter -func OptionsFromBytesWithParser(data []byte, coder OptionCodeGetter, parser OptionParser, checkEndOption bool) (Options, error) { +func (o Options) fromBytesCheckEnd(data []byte, checkEndOption bool) error { if len(data) == 0 { - return nil, nil + return nil } buf := uio.NewBigEndianBuffer(data) - options := make(map[OptionCode][]byte, 10) - var order []OptionCode - // Due to RFC 2131, 3396 allowing an option to be specified multiple - // times, we have to collect all option data first, and then parse it. var end bool for buf.Len() >= 1 { // 1 byte: option code @@ -174,9 +124,9 @@ func OptionsFromBytesWithParser(data []byte, coder OptionCodeGetter, parser Opti // n bytes: data code := buf.Read8() - if code == OptionPad.Code() { + if code == optPad { continue - } else if code == OptionEnd.Code() { + } else if code == optEnd { end = true break } @@ -185,16 +135,10 @@ func OptionsFromBytesWithParser(data []byte, coder OptionCodeGetter, parser Opti // N bytes: option data data := buf.Consume(length) if data == nil { - return nil, fmt.Errorf("error collecting options: %v", buf.Error()) + return fmt.Errorf("error collecting options: %v", buf.Error()) } data = data[:length:length] - // Get the OptionCode for this guy. - c := coder(code) - if _, ok := options[c]; !ok { - order = append(order, c) - } - // RFC 2131, Section 4.1 "Options may appear only once, [...]. // The client concatenates the values of multiple instances of // the same option into a single parameter list for @@ -202,56 +146,54 @@ func OptionsFromBytesWithParser(data []byte, coder OptionCodeGetter, parser Opti // // See also RFC 3396 for concatenation order and options longer // than 255 bytes. - options[c] = append(options[c], data...) + o[code] = append(o[code], data...) } // If we never read the End option, the sender of this packet screwed // up. if !end && checkEndOption { - return nil, io.ErrUnexpectedEOF + return io.ErrUnexpectedEOF } // Any bytes left must be padding. for buf.Len() >= 1 { - if buf.Read8() != OptionPad.Code() { - return nil, ErrInvalidOptions + if buf.Read8() != optPad { + return ErrInvalidOptions } } + return nil +} - opts := make(Options, 0, len(options)) - for _, code := range order { - parsedOpt, err := parser(code, options[code]) - if err != nil { - return nil, fmt.Errorf("error parsing option code %s: %v", code, err) - } - opts = append(opts, parsedOpt) +// sortedKeys returns an ordered slice of option keys from the Options map, for +// use in serializing options to binary. +func (o Options) sortedKeys() []int { + // Send all values for a given key + var codes []int + for k := range o { + codes = append(codes, int(k)) } - return opts, nil + + sort.Sort(sort.IntSlice(codes)) + return codes } // Marshal writes options binary representations to b. func (o Options) Marshal(b *uio.Lexer) { - for _, opt := range o { - code := opt.Code().Code() - + for _, c := range o.sortedKeys() { + code := uint8(c) // Even if the End option is in there, don't marshal it until // the end. - if code == OptionEnd.Code() { - continue - } else if code == OptionPad.Code() { - // Some DHCPv4 options have fixed length and do not put - // length on the wire. - b.Write8(code) + if code == optEnd { continue } - data := opt.ToBytes() + data := o[code] // RFC 3396: If more than 256 bytes of data are given, the // option is simply listed multiple times. for len(data) > 0 { // 1 byte: option code - b.Write8(code) + b.Write8(uint8(code)) n := len(data) if n > math.MaxUint8 { @@ -267,3 +209,137 @@ func (o Options) Marshal(b *uio.Lexer) { } } } + +// String prints options using DHCP-specified option codes. +func (o Options) String() string { + return o.ToString(dhcpHumanizer) +} + +// Summary prints options in human-readable values. +// +// Summary uses vendorParser to interpret the OptionVendorSpecificInformation option. +func (o Options) Summary(vendorDecoder OptionDecoder) string { + return o.ToString(OptionHumanizer{ + ValueHumanizer: parserFor(vendorDecoder), + CodeHumanizer: func(c uint8) OptionCode { + return optionCode(c) + }, + }) +} + +// OptionParser gives a human-legible interpretation of data for the given option code. +type OptionParser func(code OptionCode, data []byte) fmt.Stringer + +// OptionHumanizer is used to interpret a set of Options for their option code +// name and values. +// +// There should be separate OptionHumanizers for each Option "space": DHCP, +// BSDP, Relay Agent Info, and others. +type OptionHumanizer struct { + ValueHumanizer OptionParser + CodeHumanizer func(code uint8) OptionCode +} + +// Stringify returns a human-readable interpretation of the option code and its +// associated data. +func (oh OptionHumanizer) Stringify(code uint8, data []byte) string { + c := oh.CodeHumanizer(code) + val := oh.ValueHumanizer(c, data) + return fmt.Sprintf("%s: %s", c, val) +} + +// dhcpHumanizer humanizes the set of DHCP option codes. +var dhcpHumanizer = OptionHumanizer{ + ValueHumanizer: parseOption, + CodeHumanizer: func(c uint8) OptionCode { + return optionCode(c) + }, +} + +// ToString uses parse to parse options into human-readable values. +func (o Options) ToString(humanizer OptionHumanizer) string { + var ret string + for _, c := range o.sortedKeys() { + code := uint8(c) + v := o[code] + optString := humanizer.Stringify(code, v) + // If this option has sub structures, offset them accordingly. + if strings.Contains(optString, "\n") { + optString = strings.Replace(optString, "\n ", "\n ", -1) + } + ret += fmt.Sprintf(" %v\n", optString) + } + return ret +} + +func parseOption(code OptionCode, data []byte) fmt.Stringer { + return parserFor(nil)(code, data) +} + +func parserFor(vendorParser OptionDecoder) OptionParser { + return func(code OptionCode, data []byte) fmt.Stringer { + return getOption(code, data, vendorParser) + } +} + +// OptionDecoder can decode a byte stream into a human-readable option. +type OptionDecoder interface { + fmt.Stringer + FromBytes([]byte) error +} + +func getOption(code OptionCode, data []byte, vendorDecoder OptionDecoder) fmt.Stringer { + var d OptionDecoder + switch code { + case OptionRouter, OptionDomainNameServer, OptionNTPServers, OptionServerIdentifier: + d = &IPs{} + + case OptionBroadcastAddress, OptionRequestedIPAddress: + d = &IP{} + + case OptionClientSystemArchitectureType: + d = &iana.Archs{} + + case OptionSubnetMask: + d = &IPMask{} + + case OptionDHCPMessageType: + var mt MessageType + d = &mt + + case OptionParameterRequestList: + d = &OptionCodeList{} + + case OptionHostName, OptionDomainName, OptionRootPath, + OptionClassIdentifier, OptionTFTPServerName, OptionBootfileName: + var s String + d = &s + + case OptionRelayAgentInformation: + d = &RelayOptions{} + + case OptionDNSDomainSearchList: + d = &rfc1035label.Labels{} + + case OptionIPAddressLeaseTime: + var dur Duration + d = &dur + + case OptionMaximumDHCPMessageSize: + var u Uint16 + d = &u + + case OptionUserClassInformation: + d = &UserClass{} + + case OptionVendorIdentifyingVendorClass: + d = &VIVCIdentifiers{} + + case OptionVendorSpecificInformation: + d = vendorDecoder + } + if d != nil && d.FromBytes(data) == nil { + return d + } + return OptionGeneric{data} +} diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go index 0c1c1fa..6c5393c 100644 --- a/dhcpv4/options_test.go +++ b/dhcpv4/options_test.go @@ -11,154 +11,148 @@ import ( ) func TestParseOption(t *testing.T) { - // Generic - option := []byte{192, 168, 1, 254} // Name server option - opt, err := ParseOption(OptionNameServer, option) - require.NoError(t, err) - generic := opt.(*OptionGeneric) - require.Equal(t, OptionNameServer, generic.Code()) - require.Equal(t, []byte{192, 168, 1, 254}, generic.Data) - require.Equal(t, "Name Server -> [192 168 1 254]", generic.String()) - - // Option subnet mask - option = []byte{255, 255, 255, 0} - opt, err = ParseOption(OptionSubnetMask, option) - require.NoError(t, err) - require.Equal(t, OptionSubnetMask, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option router - option = []byte{192, 168, 1, 1} - opt, err = ParseOption(OptionRouter, option) - require.NoError(t, err) - require.Equal(t, OptionRouter, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option domain name server - option = []byte{192, 168, 1, 1} - opt, err = ParseOption(OptionDomainNameServer, option) - require.NoError(t, err) - require.Equal(t, OptionDomainNameServer, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option host name - option = []byte{'t', 'e', 's', 't'} - opt, err = ParseOption(OptionHostName, option) - require.NoError(t, err) - require.Equal(t, OptionHostName, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option domain name - option = []byte{'t', 'e', 's', 't'} - opt, err = ParseOption(OptionDomainName, option) - require.NoError(t, err) - require.Equal(t, OptionDomainName, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option root path - option = []byte{'/', 'f', 'o', 'o'} - opt, err = ParseOption(OptionRootPath, option) - require.NoError(t, err) - require.Equal(t, OptionRootPath, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option broadcast address - option = []byte{255, 255, 255, 255} - opt, err = ParseOption(OptionBroadcastAddress, option) - require.NoError(t, err) - require.Equal(t, OptionBroadcastAddress, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option NTP servers - option = []byte{10, 10, 10, 10} - opt, err = ParseOption(OptionNTPServers, option) - require.NoError(t, err) - require.Equal(t, OptionNTPServers, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Requested IP address - option = []byte{1, 2, 3, 4} - opt, err = ParseOption(OptionRequestedIPAddress, option) - require.NoError(t, err) - require.Equal(t, OptionRequestedIPAddress, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Requested IP address lease time - option = []byte{0, 0, 0, 0} - opt, err = ParseOption(OptionIPAddressLeaseTime, option) - require.NoError(t, err) - require.Equal(t, OptionIPAddressLeaseTime, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Message type - option = []byte{1} - opt, err = ParseOption(OptionDHCPMessageType, option) - require.NoError(t, err) - require.Equal(t, OptionDHCPMessageType, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option server ID - option = []byte{1, 2, 3, 4} - opt, err = ParseOption(OptionServerIdentifier, option) - require.NoError(t, err) - require.Equal(t, OptionServerIdentifier, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Parameter request list - option = []byte{5, 53, 61} - opt, err = ParseOption(OptionParameterRequestList, option) - require.NoError(t, err) - require.Equal(t, OptionParameterRequestList, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option max message size - option = []byte{1, 2} - opt, err = ParseOption(OptionMaximumDHCPMessageSize, option) - require.NoError(t, err) - require.Equal(t, OptionMaximumDHCPMessageSize, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option class identifier - option = []byte{'t', 'e', 's', 't'} - opt, err = ParseOption(OptionClassIdentifier, option) - require.NoError(t, err) - require.Equal(t, OptionClassIdentifier, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option TFTP server name - option = []byte{'t', 'e', 's', 't'} - opt, err = ParseOption(OptionTFTPServerName, option) - require.NoError(t, err) - require.Equal(t, OptionTFTPServerName, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") - - // Option Bootfile name - option = []byte{'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't'} - opt, err = ParseOption(OptionBootfileName, option) - require.NoError(t, err) - require.Equal(t, OptionBootfileName, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") + for _, tt := range []struct { + code OptionCode + value []byte + want string + }{ + { + code: OptionNameServer, + value: []byte{192, 168, 1, 254}, + want: "[192 168 1 254]", + }, + { + code: OptionSubnetMask, + value: []byte{255, 255, 255, 0}, + want: "ffffff00", + }, + { + code: OptionRouter, + value: []byte{192, 168, 1, 1, 192, 168, 2, 1}, + want: "192.168.1.1, 192.168.2.1", + }, + { + code: OptionDomainNameServer, + value: []byte{192, 168, 1, 1, 192, 168, 2, 1}, + want: "192.168.1.1, 192.168.2.1", + }, + { + code: OptionNTPServers, + value: []byte{192, 168, 1, 1, 192, 168, 2, 1}, + want: "192.168.1.1, 192.168.2.1", + }, + { + code: OptionServerIdentifier, + value: []byte{192, 168, 1, 1, 192, 168, 2, 1}, + want: "192.168.1.1, 192.168.2.1", + }, + { + code: OptionHostName, + value: []byte("test"), + want: "test", + }, + { + code: OptionDomainName, + value: []byte("test"), + want: "test", + }, + { + code: OptionRootPath, + value: []byte("test"), + want: "test", + }, + { + code: OptionClassIdentifier, + value: []byte("test"), + want: "test", + }, + { + code: OptionTFTPServerName, + value: []byte("test"), + want: "test", + }, + { + code: OptionBootfileName, + value: []byte("test"), + want: "test", + }, + { + code: OptionBroadcastAddress, + value: []byte{192, 168, 1, 1}, + want: "192.168.1.1", + }, + { + code: OptionRequestedIPAddress, + value: []byte{192, 168, 1, 1}, + want: "192.168.1.1", + }, + { + code: OptionIPAddressLeaseTime, + value: []byte{0, 0, 0, 12}, + want: "12s", + }, + { + code: OptionDHCPMessageType, + value: []byte{1}, + want: "DISCOVER", + }, + { + code: OptionParameterRequestList, + value: []byte{3, 4, 5}, + want: "Router, Time Server, Name Server", + }, + { + code: OptionMaximumDHCPMessageSize, + value: []byte{1, 2}, + want: "258", + }, + { + code: OptionUserClassInformation, + value: []byte{4, 't', 'e', 's', 't', 3, 'f', 'o', 'o'}, + want: "test, foo", + }, + { + code: OptionRelayAgentInformation, + value: []byte{1, 4, 129, 168, 0, 1}, + want: " unknown (1): [129 168 0 1]\n", + }, + { + code: OptionClientSystemArchitectureType, + value: []byte{0, 0}, + want: "Intel x86PC", + }, + } { + s := parseOption(tt.code, tt.value) + if got := s.String(); got != tt.want { + t.Errorf("parseOption(%s, %v) = %s, want %s", tt.code, tt.value, got, tt.want) + } + } +} - // Option user class information - option = []byte{4, 't', 'e', 's', 't'} - opt, err = ParseOption(OptionUserClassInformation, option) - require.NoError(t, err) - require.Equal(t, OptionUserClassInformation, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") +func TestOptionToBytes(t *testing.T) { + o := Option{ + Code: OptionDHCPMessageType, + Value: &OptionGeneric{[]byte{byte(MessageTypeDiscover)}}, + } + serialized := o.Value.ToBytes() + expected := []byte{1} + require.Equal(t, expected, serialized) +} - // Option relay agent information - option = []byte{1, 4, 129, 168, 0, 1} - opt, err = ParseOption(OptionRelayAgentInformation, option) - require.NoError(t, err) - require.Equal(t, OptionRelayAgentInformation, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") +func TestOptionString(t *testing.T) { + o := Option{ + Code: OptionDHCPMessageType, + Value: MessageTypeDiscover, + } + require.Equal(t, "DHCP Message Type: DISCOVER", o.String()) +} - // Option client system architecture type option - option = []byte{'t', 'e', 's', 't'} - opt, err = ParseOption(OptionClientSystemArchitectureType, option) - require.NoError(t, err) - require.Equal(t, OptionClientSystemArchitectureType, opt.Code(), "Code") - require.Equal(t, option, opt.ToBytes(), "ToBytes") +func TestOptionStringUnknown(t *testing.T) { + o := Option{ + Code: GenericOptionCode(102), // Returend option code. + Value: &OptionGeneric{[]byte{byte(MessageTypeDiscover)}}, + } + require.Equal(t, "unknown (102): [1]", o.String()) } func TestOptionsMarshal(t *testing.T) { @@ -172,10 +166,7 @@ func TestOptionsMarshal(t *testing.T) { }, { opts: Options{ - &OptionGeneric{ - OptionCode: optionCode(5), - Data: []byte{1, 2, 3, 4}, - }, + 5: []byte{1, 2, 3, 4}, }, want: []byte{ 5 /* key */, 4 /* length */, 1, 2, 3, 4, @@ -184,14 +175,9 @@ func TestOptionsMarshal(t *testing.T) { { // Test sorted key order. opts: Options{ - &OptionGeneric{ - OptionCode: optionCode(5), - Data: []byte{1, 2, 3}, - }, - &OptionGeneric{ - OptionCode: optionCode(100), - Data: []byte{101, 102, 103}, - }, + 5: []byte{1, 2, 3}, + 100: []byte{101, 102, 103}, + 255: []byte{}, }, want: []byte{ 5, 3, 1, 2, 3, @@ -201,10 +187,7 @@ func TestOptionsMarshal(t *testing.T) { { // Test RFC 3396. opts: Options{ - &OptionGeneric{ - OptionCode: optionCode(5), - Data: bytes.Repeat([]byte{10}, math.MaxUint8+1), - }, + 5: bytes.Repeat([]byte{10}, math.MaxUint8+1), }, want: append(append( []byte{5, math.MaxUint8}, bytes.Repeat([]byte{10}, math.MaxUint8)...), @@ -262,10 +245,7 @@ func TestOptionsUnmarshal(t *testing.T) { byte(OptionEnd), }, want: Options{ - &OptionGeneric{ - OptionCode: optionCode(3), - Data: []byte{5, 6}, - }, + 3: []byte{5, 6}, }, }, { @@ -276,10 +256,7 @@ func TestOptionsUnmarshal(t *testing.T) { byte(OptionEnd), ), want: Options{ - &OptionGeneric{ - OptionCode: optionCode(3), - Data: bytes.Repeat([]byte{10}, math.MaxUint8+5), - }, + 3: bytes.Repeat([]byte{10}, math.MaxUint8+5), }, }, { @@ -289,14 +266,8 @@ func TestOptionsUnmarshal(t *testing.T) { byte(OptionEnd), }, want: Options{ - &OptionGeneric{ - OptionCode: optionCode(10), - Data: []byte{255, 254}, - }, - &OptionGeneric{ - OptionCode: optionCode(11), - Data: []byte{5, 5, 5}, - }, + 10: []byte{255, 254}, + 11: []byte{5, 5, 5}, }, }, { @@ -305,15 +276,13 @@ func TestOptionsUnmarshal(t *testing.T) { byte(OptionEnd), ), want: Options{ - &OptionGeneric{ - OptionCode: optionCode(10), - Data: []byte{255, 254}, - }, + 10: []byte{255, 254}, }, }, } { t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) { - opt, err := OptionsFromBytesWithParser(tt.input, codeGetter, ParseOptionGeneric, true) + opt := make(Options) + err := opt.fromBytesCheckEnd(tt.input, true) if tt.wantError { require.Error(t, err) } else { diff --git a/dhcpv4/server_test.go b/dhcpv4/server_test.go index 4307924..8626451 100644 --- a/dhcpv4/server_test.go +++ b/dhcpv4/server_test.go @@ -41,19 +41,15 @@ func DORAHandler(conn net.PacketConn, peer net.Addr, m *DHCPv4) { log.Printf("NewReplyFromRequest failed: %v", err) return } - reply.UpdateOption(&OptServerIdentifier{ServerID: net.IP{1, 2, 3, 4}}) - opt := m.GetOneOption(OptionDHCPMessageType) - if opt == nil { - log.Printf("No message type found!") - return - } - switch opt.(*OptMessageType).MessageType { + reply.UpdateOption(OptServerIdentifier(net.IP{1, 2, 3, 4})) + mt := GetMessageType(m.Options) + switch mt { case MessageTypeDiscover: - reply.UpdateOption(&OptMessageType{MessageType: MessageTypeOffer}) + reply.UpdateOption(OptMessageType(MessageTypeOffer)) case MessageTypeRequest: - reply.UpdateOption(&OptMessageType{MessageType: MessageTypeAck}) + reply.UpdateOption(OptMessageType(MessageTypeAck)) default: - log.Printf("Unhandled message type: %v", opt.(*OptMessageType).MessageType) + log.Printf("Unhandled message type: %v", mt) return } diff --git a/dhcpv4/types.go b/dhcpv4/types.go index 6214dbd..e1762cc 100644 --- a/dhcpv4/types.go +++ b/dhcpv4/types.go @@ -2,6 +2,8 @@ package dhcpv4 import ( "fmt" + + "github.com/u-root/u-root/pkg/uio" ) // values from http://www.networksorcery.com/enp/protocol/dhcp.htm and @@ -36,6 +38,13 @@ const ( MessageTypeInform MessageType = 8 ) +// ToBytes returns the serialized version of this option described by RFC 2132, +// Section 9.6. +func (m MessageType) ToBytes() []byte { + return []byte{byte(m)} +} + +// String prints a human-readable message type name. func (m MessageType) String() string { if s, ok := messageTypeToString[m]; ok { return s @@ -43,6 +52,14 @@ func (m MessageType) String() string { return fmt.Sprintf("unknown (%d)", byte(m)) } +// FromBytes reads a message type from data as described by RFC 2132, Section +// 9.6. +func (m *MessageType) FromBytes(data []byte) error { + buf := uio.NewBigEndianBuffer(data) + *m = MessageType(buf.Read8()) + return buf.FinError() +} + var messageTypeToString = map[MessageType]string{ MessageTypeDiscover: "DISCOVER", MessageTypeOffer: "OFFER", @@ -81,22 +98,40 @@ var opcodeToString = map[OpcodeType]string{ // with the same Code value, as vendor-specific options use option codes that // have the same value, but mean a different thing. type OptionCode interface { + // Code is the 1 byte option code for the wire. Code() uint8 + + // String returns the option's name. String() string } // optionCode is a DHCP option code. type optionCode uint8 +// Code implements OptionCode.Code. func (o optionCode) Code() uint8 { return uint8(o) } +// String returns an option name. func (o optionCode) String() string { if s, ok := optionCodeToString[o]; ok { return s } - return fmt.Sprintf("unknown (%d)", o) + return fmt.Sprintf("unknown (%d)", uint8(o)) +} + +// GenericOptionCode is an unnamed option code. +type GenericOptionCode uint8 + +// Code implements OptionCode.Code. +func (o GenericOptionCode) Code() uint8 { + return uint8(o) +} + +// String returns the option's name. +func (o GenericOptionCode) String() string { + return fmt.Sprintf("unknown (%d)", uint8(o)) } // DHCPv4 Options @@ -263,7 +298,7 @@ const ( OptionEnd optionCode = 255 ) -var optionCodeToString = map[optionCode]string{ +var optionCodeToString = map[OptionCode]string{ OptionPad: "Pad", OptionSubnetMask: "Subnet Mask", OptionTimeOffset: "Time Offset", diff --git a/dhcpv4/ztpv4/ztp.go b/dhcpv4/ztpv4/ztp.go index 18075e9..4401e9d 100644 --- a/dhcpv4/ztpv4/ztp.go +++ b/dhcpv4/ztpv4/ztp.go @@ -18,11 +18,10 @@ var errVendorOptionMalformed = errors.New("malformed vendor option") // ParseVendorData will try to parse dhcp4 options looking for more // specific vendor data (like model, serial number, etc). func ParseVendorData(packet *dhcpv4.DHCPv4) (*VendorData, error) { - opt := packet.GetOneOption(dhcpv4.OptionClassIdentifier) - if opt == nil { + vc := dhcpv4.GetClassIdentifier(packet.Options) + if len(vc) == 0 { return nil, errors.New("vendor options not found") } - vc := opt.(*dhcpv4.OptClassIdentifier).Identifier vd := &VendorData{} switch { @@ -59,9 +58,8 @@ func ParseVendorData(packet *dhcpv4.DHCPv4) (*VendorData, error) { p := strings.Split(vc, "-") if len(p) < 3 { vd.Model = p[1] - if opt := packet.GetOneOption(dhcpv4.OptionHostName); opt != nil { - vd.Serial = opt.(*dhcpv4.OptHostName).HostName - } else { + vd.Serial = dhcpv4.GetHostName(packet.Options) + if len(vd.Serial) == 0 { return nil, errors.New("host name option is missing") } } else { diff --git a/dhcpv4/ztpv4/ztp_test.go b/dhcpv4/ztpv4/ztp_test.go index 680f15a..6e050d9 100644 --- a/dhcpv4/ztpv4/ztp_test.go +++ b/dhcpv4/ztpv4/ztp_test.go @@ -54,15 +54,10 @@ func TestParseV4VendorClass(t *testing.T) { } if tc.vc != "" { - packet.UpdateOption(&dhcpv4.OptClassIdentifier{ - Identifier: tc.vc, - }) + packet.UpdateOption(dhcpv4.OptClassIdentifier(tc.vc)) } - if tc.hostname != "" { - packet.UpdateOption(&dhcpv4.OptHostName{ - HostName: tc.hostname, - }) + packet.UpdateOption(dhcpv4.OptHostName(tc.hostname)) } vd, err := ParseVendorData(packet) diff --git a/iana/archtype.go b/iana/archtype.go index 510c6fc..255687c 100644 --- a/iana/archtype.go +++ b/iana/archtype.go @@ -1,5 +1,12 @@ package iana +import ( + "fmt" + "strings" + + "github.com/u-root/u-root/pkg/uio" +) + // Arch encodes an architecture type per RFC 4578, Section 2.1. type Arch uint16 @@ -38,3 +45,40 @@ func (a Arch) String() string { } return "unknown" } + +// Archs represents multiple Arch values. +type Archs []Arch + +// ToBytes returns the serialized option defined by RFC 4578 (DHCPv4) and RFC +// 5970 (DHCPv6) as the Client System Architecture Option. +func (a Archs) ToBytes() []byte { + buf := uio.NewBigEndianBuffer(nil) + for _, at := range a { + buf.Write16(uint16(at)) + } + return buf.Data() +} + +// String returns the list of archs in a human-readable manner. +func (a Archs) String() string { + s := make([]string, 0, len(a)) + for _, arch := range a { + s = append(s, arch.String()) + } + return strings.Join(s, ", ") +} + +// FromBytes parses a DHCP list of architecture types as defined by RFC 4578 +// and RFC 5970. +func (a *Archs) FromBytes(data []byte) error { + buf := uio.NewBigEndianBuffer(data) + if buf.Len() == 0 { + return fmt.Errorf("must have at least one archtype if option is present") + } + + *a = make([]Arch, 0, buf.Len()/2) + for buf.Has(2) { + *a = append(*a, Arch(buf.Read16())) + } + return buf.FinError() +} diff --git a/netboot/netconf.go b/netboot/netconf.go index 9dfa858..8c56262 100644 --- a/netboot/netconf.go +++ b/netboot/netconf.go @@ -87,11 +87,10 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) { // get the subnet mask from OptionSubnetMask. If the netmask is not defined // in the packet, an error is returned - netmaskOption := d.GetOneOption(dhcpv4.OptionSubnetMask) - if netmaskOption == nil { + netmask := dhcpv4.GetSubnetMask(d.Options) + if netmask == nil { return nil, errors.New("no netmask option in response packet") } - netmask := netmaskOption.(*dhcpv4.OptSubnetMask).SubnetMask ones, _ := netmask.Size() if ones == 0 { return nil, errors.New("netmask extracted from OptSubnetMask options is null") @@ -100,11 +99,7 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) { // netconf struct requires a valid lifetime to be specified. ValidLifetime is a dhcpv6 // concept, the closest mapping in dhcpv4 world is "IP Address Lease Time". If the lease // time option is nil, we set it to 0 - leaseTimeOption := d.GetOneOption(dhcpv4.OptionIPAddressLeaseTime) - leaseTime := uint32(0) - if leaseTimeOption != nil { - leaseTime = leaseTimeOption.(*dhcpv4.OptIPAddressLeaseTime).LeaseTime - } + leaseTime := dhcpv4.GetIPAddressLeaseTime(d.Options, 0) netconf.Addresses = append(netconf.Addresses, AddrConf{ IPNet: net.IPNet{ @@ -112,24 +107,19 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) { Mask: netmask, }, PreferredLifetime: 0, - ValidLifetime: int(leaseTime), + ValidLifetime: int(leaseTime / time.Second), }) // get DNS configuration - dnsServersOption := d.GetOneOption(dhcpv4.OptionDomainNameServer) - if dnsServersOption == nil { - return nil, errors.New("name servers option is empty") - } - dnsServers := dnsServersOption.(*dhcpv4.OptDomainNameServer).NameServers + dnsServers := dhcpv4.GetDNS(d.Options) if len(dnsServers) == 0 { return nil, errors.New("no dns servers options in response packet") } netconf.DNSServers = dnsServers // get domain search list - dnsDomainSearchListOption := d.GetOneOption(dhcpv4.OptionDNSDomainSearchList) - if dnsDomainSearchListOption != nil { - dnsSearchList := dnsDomainSearchListOption.(*dhcpv4.OptDomainSearch).DomainSearch + dnsSearchList := dhcpv4.GetDomainSearch(d.Options) + if dnsSearchList != nil { if len(dnsSearchList.Labels) == 0 { return nil, errors.New("dns search list is empty") } @@ -137,18 +127,11 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) { } // get default gateway - routerOption := d.GetOneOption(dhcpv4.OptionRouter) - if routerOption == nil { - return nil, errors.New("no router option specified in response packet") - } - - routersList := routerOption.(*dhcpv4.OptRouter).Routers + routersList := dhcpv4.GetRouter(d.Options) if len(routersList) == 0 { return nil, errors.New("no routers specified in the corresponding option") } - netconf.Routers = routersList - return &netconf, nil } diff --git a/netboot/netconf_test.go b/netboot/netconf_test.go index 00b39b8..5e954fa 100644 --- a/netboot/netconf_test.go +++ b/netboot/netconf_test.go @@ -224,13 +224,13 @@ func TestGetNetConfFromPacketv4(t *testing.T) { require.Equal(t, 5200, netconf.Addresses[0].ValidLifetime) // check DNSes require.Equal(t, 2, len(netconf.DNSServers)) - require.Equal(t, net.ParseIP("10.10.0.1"), netconf.DNSServers[0]) - require.Equal(t, net.ParseIP("10.10.0.2"), netconf.DNSServers[1]) + require.Equal(t, net.ParseIP("10.10.0.1").To4(), netconf.DNSServers[0]) + require.Equal(t, net.ParseIP("10.10.0.2").To4(), netconf.DNSServers[1]) // check DNS search list require.Equal(t, 2, len(netconf.DNSSearchList)) require.Equal(t, "slackware.it", netconf.DNSSearchList[0]) require.Equal(t, "dhcp.slackware.it", netconf.DNSSearchList[1]) // check routers require.Equal(t, 1, len(netconf.Routers)) - require.Equal(t, net.ParseIP("10.0.0.254"), netconf.Routers[0]) + require.Equal(t, net.ParseIP("10.0.0.254").To4(), netconf.Routers[0]) } diff --git a/rfc1035label/label.go b/rfc1035label/label.go index 5093de8..5a67d7c 100644 --- a/rfc1035label/label.go +++ b/rfc1035label/label.go @@ -2,13 +2,14 @@ package rfc1035label import ( "errors" + "fmt" "strings" ) +// Labels represents RFC1035 labels +// // This implements RFC 1035 labels, including compression. // https://tools.ietf.org/html/rfc1035#section-4.1.4 - -// Labels represents RFC1035 labels type Labels struct { // original contains the original bytes if the object was parsed from a byte // sequence, or nil otherwise. The `original` field is necessary to deal @@ -33,6 +34,11 @@ func same(a, b []string) bool { return true } +// String prints labels. +func (l *Labels) String() string { + return fmt.Sprintf("%v", l.Labels) +} + // ToBytes returns a byte sequence representing the labels. If the original // sequence is modified, the labels are parsed again, otherwise the original // byte sequence is returned. @@ -62,17 +68,25 @@ func NewLabels() *Labels { } } +// FromBytes reads labels from a bytes stream according to RFC 1035. +func (l *Labels) FromBytes(data []byte) error { + labs, err := labelsFromBytes(data) + if err != nil { + return err + } + l.original = data + l.Labels = labs + return nil +} + // FromBytes returns a Labels object from the given byte sequence, or an error if // any. func FromBytes(data []byte) (*Labels, error) { - lab := NewLabels() - l, err := labelsFromBytes(data) - if err != nil { + var l Labels + if err := l.FromBytes(data); err != nil { return nil, err } - lab.original = data - lab.Labels = l - return lab, nil + return &l, nil } // fromBytes decodes a serialized stream and returns a list of labels |