diff options
author | Christopher Koch <c@chrisko.ch> | 2019-01-19 21:29:26 +0000 |
---|---|---|
committer | insomniac <insomniacslk@users.noreply.github.com> | 2019-01-19 22:32:20 +0000 |
commit | fe6f307df5d78a54ddd4a56a275043317148fe5a (patch) | |
tree | 96c357bf87bd4939b503763ffc94c66aa73e336c | |
parent | 5e6e8baddaa29b866abe0b865e0c66c9190ec2f7 (diff) |
dhcpv4: build more packets with modifiers
Also drop unnecessary return value of Modifier.
-rw-r--r-- | dhcpv4/async/client.go | 6 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp.go | 72 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp_test.go | 1 | ||||
-rw-r--r-- | dhcpv4/client.go | 5 | ||||
-rw-r--r-- | dhcpv4/dhcpv4.go | 115 | ||||
-rw-r--r-- | dhcpv4/modifiers.go | 73 | ||||
-rw-r--r-- | dhcpv4/modifiers_test.go | 96 | ||||
-rw-r--r-- | netboot/netconf_test.go | 112 |
8 files changed, 238 insertions, 242 deletions
diff --git a/dhcpv4/async/client.go b/dhcpv4/async/client.go index 9844180..81b10f7 100644 --- a/dhcpv4/async/client.go +++ b/dhcpv4/async/client.go @@ -194,11 +194,7 @@ func (c *Client) remoteAddr() (*net.UDPAddr, error) { // Send inserts a message to the queue to be sent asynchronously. // Returns a future which resolves to response and error. -func (c *Client) Send(message *dhcpv4.DHCPv4, modifiers ...dhcpv4.Modifier) *promise.Future { - for _, mod := range modifiers { - message = mod(message) - } - +func (c *Client) Send(message *dhcpv4.DHCPv4) *promise.Future { p := promise.NewPromise() c.packetsLock.Lock() c.packets[message.TransactionID] = p diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 21cc71a..3cc87d2 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -96,17 +96,17 @@ func NewInformListForInterface(ifname string, replyPort uint16) (*dhcpv4.DHCPv4, // NewInformList creates a new INFORM packet for interface with hardware address // `hwaddr` and IP `localIP`. Packet will be sent out on port `replyPort`. -func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16) (*dhcpv4.DHCPv4, error) { - d, err := dhcpv4.NewInform(hwaddr, localIP) - if err != nil { - return nil, err - } - +func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16, modifiers ...dhcpv4.Modifier) (*dhcpv4.DHCPv4, error) { // Validate replyPort first if needsReplyPort(replyPort) && replyPort >= 1024 { return nil, errors.New("replyPort must be a privileged port") } + vendorClassID, err := MakeVendorClassIdentifier() + if err != nil { + return nil, err + } + // These are vendor-specific options used to pass along BSDP information. vendorOpts := []dhcpv4.Option{ &OptMessageType{MessageTypeList}, @@ -115,44 +115,24 @@ func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16) (* if needsReplyPort(replyPort) { vendorOpts = append(vendorOpts, &OptReplyPort{replyPort}) } - d.UpdateOption(&OptVendorSpecificInformation{vendorOpts}) - d.UpdateOption(&dhcpv4.OptParameterRequestList{ - RequestedOpts: []dhcpv4.OptionCode{ + return dhcpv4.NewInform(hwaddr, localIP, + dhcpv4.PrependModifiers(modifiers, dhcpv4.WithRequestedOptions( dhcpv4.OptionVendorSpecificInformation, dhcpv4.OptionClassIdentifier, - }, - }) - d.UpdateOption(&dhcpv4.OptMaximumDHCPMessageSize{Size: MaxDHCPMessageSize}) - - vendorClassID, err := MakeVendorClassIdentifier() - if err != nil { - return nil, err - } - d.UpdateOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}) - return d, nil + ), + dhcpv4.WithOption(&dhcpv4.OptMaximumDHCPMessageSize{Size: MaxDHCPMessageSize}), + dhcpv4.WithOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}), + dhcpv4.WithOption(&OptVendorSpecificInformation{vendorOpts}), + )...) } // InformSelectForAck constructs an INFORM[SELECT] packet given an ACK to the // previously-sent INFORM[LIST]. func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootImage) (*dhcpv4.DHCPv4, error) { - d, err := dhcpv4.New() - if err != nil { - return nil, err - } - if needsReplyPort(replyPort) && replyPort >= 1024 { return nil, errors.New("replyPort must be a privileged port") } - d.OpCode = dhcpv4.OpcodeBootRequest - d.HWType = ack.HWType - d.ClientHWAddr = ack.ClientHWAddr - d.TransactionID = ack.TransactionID - if ack.IsBroadcast() { - d.SetBroadcast() - } else { - d.SetUnicast() - } // Data for OptionSelectedBootImageID vendorOpts := []dhcpv4.Option{ @@ -161,6 +141,11 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI &OptSelectedBootImageID{selectedImage.ID}, } + // Validate replyPort if requested. + if needsReplyPort(replyPort) { + vendorOpts = append(vendorOpts, &OptReplyPort{replyPort}) + } + // Find server IP address var serverIP net.IP if opt := ack.GetOneOption(dhcpv4.OptionServerIdentifier); opt != nil { @@ -171,28 +156,23 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI } vendorOpts = append(vendorOpts, &OptServerIdentifier{serverIP}) - // Validate replyPort if requested. - if needsReplyPort(replyPort) { - vendorOpts = append(vendorOpts, &OptReplyPort{replyPort}) - } - vendorClassID, err := MakeVendorClassIdentifier() if err != nil { return nil, err } - d.UpdateOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}) - d.UpdateOption(&dhcpv4.OptParameterRequestList{ - RequestedOpts: []dhcpv4.OptionCode{ + + return dhcpv4.New(dhcpv4.WithReply(&ack), + dhcpv4.WithOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}), + dhcpv4.WithRequestedOptions( dhcpv4.OptionSubnetMask, dhcpv4.OptionRouter, dhcpv4.OptionBootfileName, dhcpv4.OptionVendorSpecificInformation, dhcpv4.OptionClassIdentifier, - }, - }) - d.UpdateOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeInform}) - d.UpdateOption(&OptVendorSpecificInformation{vendorOpts}) - return d, nil + ), + dhcpv4.WithMessageType(dhcpv4.MessageTypeInform), + dhcpv4.WithOption(&OptVendorSpecificInformation{vendorOpts}), + ) } // NewReplyForInformList constructs an ACK for the INFORM[LIST] packet `inform` diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go index 2caa6e5..638a408 100644 --- a/dhcpv4/bsdp/bsdp_test.go +++ b/dhcpv4/bsdp/bsdp_test.go @@ -104,6 +104,7 @@ func TestNewInformList_ReplyPort(t *testing.T) { func newAck(hwAddr net.HardwareAddr, transactionID [4]byte) *dhcpv4.DHCPv4 { ack, _ := dhcpv4.New() + ack.OpCode = dhcpv4.OpcodeBootReply ack.TransactionID = transactionID ack.HWType = iana.HWTypeEthernet ack.ClientHWAddr = hwAddr diff --git a/dhcpv4/client.go b/dhcpv4/client.go index d2d18db..ad80b34 100644 --- a/dhcpv4/client.go +++ b/dhcpv4/client.go @@ -220,13 +220,10 @@ func (c *Client) Exchange(ifname string, modifiers ...Modifier) ([]*DHCPv4, erro }() // Discover - discover, err := NewDiscoveryForInterface(ifname) + discover, err := NewDiscoveryForInterface(ifname, modifiers...) if err != nil { return conversation, err } - for _, mod := range modifiers { - discover = mod(discover) - } conversation = append(conversation, discover) // Offer diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index 429d404..93a6531 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -50,7 +50,7 @@ type DHCPv4 struct { // Modifier defines the signature for functions that can modify DHCPv4 // structures. This is used to simplify packet manipulation -type Modifier func(d *DHCPv4) *DHCPv4 +type Modifier func(d *DHCPv4) // IPv4AddrsForInterface obtains the currently-configured, non-loopback IPv4 // addresses for iface. @@ -106,7 +106,7 @@ func GenerateTransactionID() (TransactionID, error) { // won't be a valid DHCPv4 message so you will need to adjust its fields. // See also NewDiscovery, NewOffer, NewRequest, NewAcknowledge, NewInform and // NewRelease . -func New() (*DHCPv4, error) { +func New(modifiers ...Modifier) (*DHCPv4, error) { xid, err := GenerateTransactionID() if err != nil { return nil, err @@ -124,42 +124,37 @@ func New() (*DHCPv4, error) { GatewayIPAddr: net.IPv4zero, Options: make([]Option, 0, 10), } + for _, mod := range modifiers { + mod(&d) + } return &d, nil } // NewDiscoveryForInterface builds a new DHCPv4 Discovery message, with a default // Ethernet HW type and the hardware address obtained from the specified // interface. -func NewDiscoveryForInterface(ifname string) (*DHCPv4, error) { +func NewDiscoveryForInterface(ifname string, modifiers ...Modifier) (*DHCPv4, error) { iface, err := net.InterfaceByName(ifname) if err != nil { return nil, err } - return NewDiscovery(iface.HardwareAddr) + return NewDiscovery(iface.HardwareAddr, modifiers...) } // NewDiscovery builds a new DHCPv4 Discovery message, with a default Ethernet // HW type and specified hardware address. -func NewDiscovery(hwaddr net.HardwareAddr) (*DHCPv4, error) { - d, err := New() - if err != nil { - return nil, err - } - // get hw addr - d.OpCode = OpcodeBootRequest - d.HWType = iana.HWTypeEthernet - d.ClientHWAddr = hwaddr - d.SetBroadcast() - d.UpdateOption(&OptMessageType{MessageType: MessageTypeDiscover}) - d.UpdateOption(&OptParameterRequestList{ - RequestedOpts: []OptionCode{ +func NewDiscovery(hwaddr net.HardwareAddr, modifiers ...Modifier) (*DHCPv4, error) { + return New(PrependModifiers(modifiers, + WithBroadcast(true), + WithHwAddr(hwaddr), + WithRequestedOptions( OptionSubnetMask, OptionRouter, OptionDomainName, OptionDomainNameServer, - }, - }) - return d, nil + ), + WithMessageType(MessageTypeDiscover), + )...) } // NewInformForInterface builds a new DHCPv4 Informational message with default @@ -190,74 +185,46 @@ func NewInformForInterface(ifname string, needsBroadcast bool) (*DHCPv4, error) return pkt, nil } -// NewInform builds a new DHCPv4 Informational message with default Ethernet HW -// type and specified hardware address. It does NOT put a DHCP End option at the -// end. -func NewInform(hwaddr net.HardwareAddr, localIP net.IP) (*DHCPv4, error) { - d, err := New() - if err != nil { - return nil, err - } +// PrependModifiers prepends other to m. +func PrependModifiers(m []Modifier, other ...Modifier) []Modifier { + return append(other, m...) +} - d.OpCode = OpcodeBootRequest - d.HWType = iana.HWTypeEthernet - d.ClientHWAddr = hwaddr - d.ClientIPAddr = localIP - d.UpdateOption(&OptMessageType{MessageType: MessageTypeInform}) - return d, nil +// NewInform builds a new DHCPv4 Informational message with the specified +// hardware address. +func NewInform(hwaddr net.HardwareAddr, localIP net.IP, modifiers ...Modifier) (*DHCPv4, error) { + return New(PrependModifiers( + modifiers, + WithHwAddr(hwaddr), + WithMessageType(MessageTypeInform), + WithClientIP(localIP), + )...) } // NewRequestFromOffer builds a DHCPv4 request from an offer. func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) { - d, err := New() - if err != nil { - return nil, err - } - d.OpCode = OpcodeBootRequest - d.HWType = offer.HWType - d.ClientHWAddr = offer.ClientHWAddr - d.TransactionID = offer.TransactionID - if offer.IsBroadcast() { - d.SetBroadcast() - } else { - d.SetUnicast() - } // find server IP address - var serverIP []byte - for _, opt := range offer.Options { - if opt.Code() == OptionServerIdentifier { - serverIP = opt.(*OptServerIdentifier).ServerID - } + var serverIP net.IP + serverID := offer.GetOneOption(OptionServerIdentifier) + if serverID != nil { + serverIP = serverID.(*OptServerIdentifier).ServerID } if serverIP == nil { return nil, errors.New("Missing Server IP Address in DHCP Offer") } - d.ServerIPAddr = serverIP - d.UpdateOption(&OptMessageType{MessageType: MessageTypeRequest}) - d.UpdateOption(&OptRequestedIPAddress{RequestedAddr: offer.YourIPAddr}) - d.UpdateOption(&OptServerIdentifier{ServerID: serverIP}) - for _, mod := range modifiers { - d = mod(d) - } - return d, nil + + return New(PrependModifiers(modifiers, + WithReply(offer), + WithMessageType(MessageTypeRequest), + WithServerIP(serverIP), + WithOption(&OptRequestedIPAddress{RequestedAddr: offer.YourIPAddr}), + WithOption(&OptServerIdentifier{ServerID: serverIP}), + )...) } // NewReplyFromRequest builds a DHCPv4 reply from a request. func NewReplyFromRequest(request *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) { - reply, err := New() - if err != nil { - return nil, err - } - reply.OpCode = OpcodeBootReply - reply.HWType = request.HWType - reply.ClientHWAddr = request.ClientHWAddr - reply.TransactionID = request.TransactionID - reply.Flags = request.Flags - reply.GatewayIPAddr = request.GatewayIPAddr - for _, mod := range modifiers { - reply = mod(reply) - } - return reply, nil + return New(PrependModifiers(modifiers, WithReply(request))...) } // FromBytes encodes the DHCPv4 packet into a sequence of bytes, and returns an diff --git a/dhcpv4/modifiers.go b/dhcpv4/modifiers.go index cd80bc9..0759491 100644 --- a/dhcpv4/modifiers.go +++ b/dhcpv4/modifiers.go @@ -3,42 +3,84 @@ package dhcpv4 import ( "net" + "github.com/insomniacslk/dhcp/iana" "github.com/insomniacslk/dhcp/rfc1035label" ) // WithTransactionID sets the Transaction ID for the DHCPv4 packet func WithTransactionID(xid TransactionID) Modifier { - return func(d *DHCPv4) *DHCPv4 { + return func(d *DHCPv4) { d.TransactionID = xid - return d + } +} + +// WithClientIP sets the Client IP for a DHCPv4 packet. +func WithClientIP(ip net.IP) Modifier { + return func(d *DHCPv4) { + d.ClientIPAddr = ip + } +} + +// WithYourIP sets the Your IP for a DHCPv4 packet. +func WithYourIP(ip net.IP) Modifier { + return func(d *DHCPv4) { + d.YourIPAddr = ip + } +} + +// WithServerIP sets the Server IP for a DHCPv4 packet. +func WithServerIP(ip net.IP) Modifier { + return func(d *DHCPv4) { + d.ServerIPAddr = ip + } +} + +// WithReply fills in opcode, hwtype, xid, clienthwaddr, flags, and gateway ip +// addr from the given packet. +func WithReply(request *DHCPv4) Modifier { + return func(d *DHCPv4) { + if request.OpCode == OpcodeBootRequest { + d.OpCode = OpcodeBootReply + } else { + d.OpCode = OpcodeBootRequest + } + d.HWType = request.HWType + d.TransactionID = request.TransactionID + d.ClientHWAddr = request.ClientHWAddr + d.Flags = request.Flags + d.GatewayIPAddr = request.GatewayIPAddr + } +} + +// WithHWType sets the Hardware Type for a DHCPv4 packet. +func WithHWType(hwt iana.HWType) Modifier { + return func(d *DHCPv4) { + d.HWType = hwt } } // WithBroadcast sets the packet to be broadcast or unicast func WithBroadcast(broadcast bool) Modifier { - return func(d *DHCPv4) *DHCPv4 { + return func(d *DHCPv4) { if broadcast { d.SetBroadcast() } else { d.SetUnicast() } - return d } } // WithHwAddr sets the hardware address for a packet func WithHwAddr(hwaddr net.HardwareAddr) Modifier { - return func(d *DHCPv4) *DHCPv4 { + return func(d *DHCPv4) { d.ClientHWAddr = hwaddr - return d } } // WithOption appends a DHCPv4 option provided by the user func WithOption(opt Option) Modifier { - return func(d *DHCPv4) *DHCPv4 { + return func(d *DHCPv4) { d.UpdateOption(opt) - return d } } @@ -54,13 +96,18 @@ func WithUserClass(uc []byte, rfc bool) Modifier { } // WithNetboot adds bootfile URL and bootfile param options to a DHCPv4 packet. -func WithNetboot(d *DHCPv4) *DHCPv4 { - return WithRequestedOptions(OptionTFTPServerName, OptionBootfileName)(d) +func WithNetboot(d *DHCPv4) { + WithRequestedOptions(OptionTFTPServerName, OptionBootfileName)(d) +} + +// WithMessageType adds the DHCPv4 message type m to a packet. +func WithMessageType(m MessageType) Modifier { + return WithOption(&OptMessageType{m}) } // WithRequestedOptions adds requested options to the packet. func WithRequestedOptions(optionCodes ...OptionCode) Modifier { - return func(d *DHCPv4) *DHCPv4 { + return func(d *DHCPv4) { params := d.GetOneOption(OptionParameterRequestList) if params == nil { d.UpdateOption(&OptParameterRequestList{OptionCodeList(optionCodes)}) @@ -68,18 +115,16 @@ func WithRequestedOptions(optionCodes ...OptionCode) Modifier { opts := params.(*OptParameterRequestList) opts.RequestedOpts.Add(optionCodes...) } - return d } } // WithRelay adds parameters required for DHCPv4 to be relayed by the relay // server with given ip func WithRelay(ip net.IP) Modifier { - return func(d *DHCPv4) *DHCPv4 { + return func(d *DHCPv4) { d.SetUnicast() d.GatewayIPAddr = ip d.HopCount += 1 - return d } } diff --git a/dhcpv4/modifiers_test.go b/dhcpv4/modifiers_test.go index d9bb3c7..2cac2a0 100644 --- a/dhcpv4/modifiers_test.go +++ b/dhcpv4/modifiers_test.go @@ -10,33 +10,34 @@ import ( func TestTransactionIDModifier(t *testing.T) { d, err := New() require.NoError(t, err) - d = WithTransactionID(TransactionID{0xdd, 0xcc, 0xbb, 0xaa})(d) + WithTransactionID(TransactionID{0xdd, 0xcc, 0xbb, 0xaa})(d) require.Equal(t, TransactionID{0xdd, 0xcc, 0xbb, 0xaa}, d.TransactionID) } func TestBroadcastModifier(t *testing.T) { d, err := New() require.NoError(t, err) + // set and test broadcast - d = WithBroadcast(true)(d) + WithBroadcast(true)(d) require.Equal(t, true, d.IsBroadcast()) + // set and test unicast - d = WithBroadcast(false)(d) + WithBroadcast(false)(d) require.Equal(t, true, d.IsUnicast()) } func TestHwAddrModifier(t *testing.T) { - d, err := New() + hwaddr := net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf} + d, err := New(WithHwAddr(hwaddr)) require.NoError(t, err) - hwaddr := net.HardwareAddr{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} - d = WithHwAddr(hwaddr)(d) require.Equal(t, hwaddr, d.ClientHWAddr) } func TestWithOptionModifier(t *testing.T) { - d, err := New() + d, err := New(WithOption(&OptDomainName{DomainName: "slackware.it"})) require.NoError(t, err) - d = WithOption(&OptDomainName{DomainName: "slackware.it"})(d) + opt := d.GetOneOption(OptionDomainName) require.NotNil(t, opt) dnOpt := opt.(*OptDomainName) @@ -44,10 +45,9 @@ func TestWithOptionModifier(t *testing.T) { } func TestUserClassModifier(t *testing.T) { - d, err := New() + d, err := New(WithUserClass([]byte("linuxboot"), false)) require.NoError(t, err) - userClass := WithUserClass([]byte("linuxboot"), false) - d = userClass(d) + expected := []byte{ 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } @@ -56,9 +56,9 @@ func TestUserClassModifier(t *testing.T) { } func TestUserClassModifierRFC(t *testing.T) { - d, _ := New() - userClass := WithUserClass([]byte("linuxboot"), true) - d = userClass(d) + d, err := New(WithUserClass([]byte("linuxboot"), true)) + require.NoError(t, err) + expected := []byte{ 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } @@ -67,53 +67,52 @@ func TestUserClassModifierRFC(t *testing.T) { } func TestWithNetboot(t *testing.T) { - d, _ := New() - d = WithNetboot(d) + d, err := New(WithNetboot) + require.NoError(t, err) + require.Equal(t, "Parameter Request List -> TFTP Server Name, Bootfile Name", d.Options[0].String()) } func TestWithNetbootExistingTFTP(t *testing.T) { - d, _ := New() - OptParams := &OptParameterRequestList{ + d, err := New() + require.NoError(t, err) + d.UpdateOption(&OptParameterRequestList{ RequestedOpts: []OptionCode{OptionTFTPServerName}, - } - d.UpdateOption(OptParams) - d = WithNetboot(d) + }) + WithNetboot(d) require.Equal(t, "Parameter Request List -> TFTP Server Name, Bootfile Name", d.Options[0].String()) } func TestWithNetbootExistingBootfileName(t *testing.T) { d, _ := New() - OptParams := &OptParameterRequestList{ + d.UpdateOption(&OptParameterRequestList{ RequestedOpts: []OptionCode{OptionBootfileName}, - } - d.UpdateOption(OptParams) - d = WithNetboot(d) + }) + WithNetboot(d) require.Equal(t, "Parameter Request List -> Bootfile Name, TFTP Server Name", d.Options[0].String()) } func TestWithNetbootExistingBoth(t *testing.T) { d, _ := New() - OptParams := &OptParameterRequestList{ + d.UpdateOption(&OptParameterRequestList{ RequestedOpts: []OptionCode{OptionBootfileName, OptionTFTPServerName}, - } - d.UpdateOption(OptParams) - d = WithNetboot(d) + }) + WithNetboot(d) require.Equal(t, "Parameter Request List -> Bootfile Name, TFTP Server Name", d.Options[0].String()) } func TestWithRequestedOptions(t *testing.T) { // Check if OptionParameterRequestList is created when not present - d, err := New() + d, err := New(WithRequestedOptions(OptionFQDN)) require.NoError(t, err) - d = WithRequestedOptions(OptionFQDN)(d) require.NotNil(t, d) o := d.GetOneOption(OptionParameterRequestList) require.NotNil(t, o) opts := o.(*OptParameterRequestList) require.ElementsMatch(t, opts.RequestedOpts, []OptionCode{OptionFQDN}) + // Check if already set options are preserved - d = WithRequestedOptions(OptionHostName)(d) + WithRequestedOptions(OptionHostName)(d) require.NotNil(t, d) o = d.GetOneOption(OptionParameterRequestList) require.NotNil(t, o) @@ -122,20 +121,19 @@ func TestWithRequestedOptions(t *testing.T) { } func TestWithRelay(t *testing.T) { - d, err := New() + ip := net.IP{10, 0, 0, 1} + d, err := New(WithRelay(ip)) require.NoError(t, err) - ip := net.ParseIP("10.0.0.1") - require.NotNil(t, ip) - d = WithRelay(ip)(d) - require.NotNil(t, d) + require.True(t, d.IsUnicast(), "expected unicast") require.Equal(t, ip, d.GatewayIPAddr) require.Equal(t, uint8(1), d.HopCount) } func TestWithNetmask(t *testing.T) { - d := &DHCPv4{} - d = WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d, err := New(WithNetmask(net.IPv4Mask(255, 255, 255, 0))) + require.NoError(t, err) + require.Equal(t, 1, len(d.Options)) require.Equal(t, OptionSubnetMask, d.Options[0].Code()) osm := d.Options[0].(*OptSubnetMask) @@ -143,8 +141,9 @@ func TestWithNetmask(t *testing.T) { } func TestWithLeaseTime(t *testing.T) { - d := &DHCPv4{} - d = WithLeaseTime(uint32(3600))(d) + d, err := New(WithLeaseTime(uint32(3600))) + require.NoError(t, err) + require.Equal(t, 1, len(d.Options)) require.Equal(t, OptionIPAddressLeaseTime, d.Options[0].Code()) olt := d.Options[0].(*OptIPAddressLeaseTime) @@ -152,8 +151,9 @@ func TestWithLeaseTime(t *testing.T) { } func TestWithDNS(t *testing.T) { - d := &DHCPv4{} - d = WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"))(d) + d, err := New(WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"))) + require.NoError(t, err) + require.Equal(t, 1, len(d.Options)) require.Equal(t, OptionDomainNameServer, d.Options[0].Code()) olt := d.Options[0].(*OptDomainNameServer) @@ -164,8 +164,9 @@ func TestWithDNS(t *testing.T) { } func TestWithDomainSearchList(t *testing.T) { - d := &DHCPv4{} - d = WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) + d, err := New(WithDomainSearchList("slackware.it", "dhcp.slackware.it")) + require.NoError(t, err) + require.Equal(t, 1, len(d.Options)) osl := d.Options[0].(*OptDomainSearch) require.Equal(t, OptionDNSDomainSearchList, osl.Code()) @@ -176,9 +177,10 @@ func TestWithDomainSearchList(t *testing.T) { } func TestWithRouter(t *testing.T) { - d := &DHCPv4{} rtr := net.ParseIP("10.0.0.254") - d = WithRouter(rtr)(d) + d, err := New(WithRouter(rtr)) + require.NoError(t, err) + require.Equal(t, 1, len(d.Options)) ortr := d.Options[0].(*OptRouter) require.Equal(t, OptionRouter, ortr.Code()) diff --git a/netboot/netconf_test.go b/netboot/netconf_test.go index 91d4482..00b39b8 100644 --- a/netboot/netconf_test.go +++ b/netboot/netconf_test.go @@ -107,106 +107,114 @@ func TestGetNetConfFromPacketv6(t *testing.T) { } func TestGetNetConfFromPacketv4AddrZero(t *testing.T) { - d := dhcpv4.DHCPv4{} - d.YourIPAddr = net.IPv4zero - _, err := GetNetConfFromPacketv4(&d) + d, _ := dhcpv4.New(dhcpv4.WithYourIP(net.IPv4zero)) + _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4NoMask(t *testing.T) { - d := dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - _, err := GetNetConfFromPacketv4(&d) + d, _ := dhcpv4.New(dhcpv4.WithYourIP(net.ParseIP("10.0.0.1"))) + _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4NullMask(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(0, 0, 0, 0))(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(0, 0, 0, 0)), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4NoLeaseTime(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4NoDNS(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - d = dhcpv4.WithLeaseTime(uint32(0))(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithLeaseTime(uint32(0)), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4EmptyDNSList(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - d = dhcpv4.WithLeaseTime(uint32(0))(d) - d = dhcpv4.WithDNS()(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithLeaseTime(uint32(0)), + dhcpv4.WithDNS(), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4NoSearchList(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - d = dhcpv4.WithLeaseTime(uint32(0))(d) - d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithLeaseTime(uint32(0)), + dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4EmptySearchList(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - d = dhcpv4.WithLeaseTime(uint32(0))(d) - d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) - d = dhcpv4.WithDomainSearchList()(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithLeaseTime(uint32(0)), + dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), + dhcpv4.WithDomainSearchList(), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4NoRouter(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - d = dhcpv4.WithLeaseTime(uint32(0))(d) - d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) - d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithLeaseTime(uint32(0)), + dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), + dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4EmptyRouter(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - d = dhcpv4.WithLeaseTime(uint32(0))(d) - d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) - d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) - d = dhcpv4.WithRouter()(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithLeaseTime(uint32(0)), + dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), + dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"), + dhcpv4.WithRouter(), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) _, err := GetNetConfFromPacketv4(d) require.Error(t, err) } func TestGetNetConfFromPacketv4(t *testing.T) { - d := &dhcpv4.DHCPv4{} - d.YourIPAddr = net.ParseIP("10.0.0.1") - d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - d = dhcpv4.WithLeaseTime(uint32(5200))(d) - d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) - d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) - d = dhcpv4.WithRouter(net.ParseIP("10.0.0.254"))(d) + d, _ := dhcpv4.New( + dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), + dhcpv4.WithLeaseTime(uint32(5200)), + dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), + dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"), + dhcpv4.WithRouter(net.ParseIP("10.0.0.254")), + dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), + ) + netconf, err := GetNetConfFromPacketv4(d) require.NoError(t, err) // check addresses |