summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorChristopher Koch <c@chrisko.ch>2019-01-19 21:29:26 +0000
committerinsomniac <insomniacslk@users.noreply.github.com>2019-01-19 22:32:20 +0000
commitfe6f307df5d78a54ddd4a56a275043317148fe5a (patch)
tree96c357bf87bd4939b503763ffc94c66aa73e336c
parent5e6e8baddaa29b866abe0b865e0c66c9190ec2f7 (diff)
dhcpv4: build more packets with modifiers
Also drop unnecessary return value of Modifier.
-rw-r--r--dhcpv4/async/client.go6
-rw-r--r--dhcpv4/bsdp/bsdp.go72
-rw-r--r--dhcpv4/bsdp/bsdp_test.go1
-rw-r--r--dhcpv4/client.go5
-rw-r--r--dhcpv4/dhcpv4.go115
-rw-r--r--dhcpv4/modifiers.go73
-rw-r--r--dhcpv4/modifiers_test.go96
-rw-r--r--netboot/netconf_test.go112
8 files changed, 238 insertions, 242 deletions
diff --git a/dhcpv4/async/client.go b/dhcpv4/async/client.go
index 9844180..81b10f7 100644
--- a/dhcpv4/async/client.go
+++ b/dhcpv4/async/client.go
@@ -194,11 +194,7 @@ func (c *Client) remoteAddr() (*net.UDPAddr, error) {
// Send inserts a message to the queue to be sent asynchronously.
// Returns a future which resolves to response and error.
-func (c *Client) Send(message *dhcpv4.DHCPv4, modifiers ...dhcpv4.Modifier) *promise.Future {
- for _, mod := range modifiers {
- message = mod(message)
- }
-
+func (c *Client) Send(message *dhcpv4.DHCPv4) *promise.Future {
p := promise.NewPromise()
c.packetsLock.Lock()
c.packets[message.TransactionID] = p
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go
index 21cc71a..3cc87d2 100644
--- a/dhcpv4/bsdp/bsdp.go
+++ b/dhcpv4/bsdp/bsdp.go
@@ -96,17 +96,17 @@ func NewInformListForInterface(ifname string, replyPort uint16) (*dhcpv4.DHCPv4,
// NewInformList creates a new INFORM packet for interface with hardware address
// `hwaddr` and IP `localIP`. Packet will be sent out on port `replyPort`.
-func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16) (*dhcpv4.DHCPv4, error) {
- d, err := dhcpv4.NewInform(hwaddr, localIP)
- if err != nil {
- return nil, err
- }
-
+func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16, modifiers ...dhcpv4.Modifier) (*dhcpv4.DHCPv4, error) {
// Validate replyPort first
if needsReplyPort(replyPort) && replyPort >= 1024 {
return nil, errors.New("replyPort must be a privileged port")
}
+ vendorClassID, err := MakeVendorClassIdentifier()
+ if err != nil {
+ return nil, err
+ }
+
// These are vendor-specific options used to pass along BSDP information.
vendorOpts := []dhcpv4.Option{
&OptMessageType{MessageTypeList},
@@ -115,44 +115,24 @@ func NewInformList(hwaddr net.HardwareAddr, localIP net.IP, replyPort uint16) (*
if needsReplyPort(replyPort) {
vendorOpts = append(vendorOpts, &OptReplyPort{replyPort})
}
- d.UpdateOption(&OptVendorSpecificInformation{vendorOpts})
- d.UpdateOption(&dhcpv4.OptParameterRequestList{
- RequestedOpts: []dhcpv4.OptionCode{
+ return dhcpv4.NewInform(hwaddr, localIP,
+ dhcpv4.PrependModifiers(modifiers, dhcpv4.WithRequestedOptions(
dhcpv4.OptionVendorSpecificInformation,
dhcpv4.OptionClassIdentifier,
- },
- })
- d.UpdateOption(&dhcpv4.OptMaximumDHCPMessageSize{Size: MaxDHCPMessageSize})
-
- vendorClassID, err := MakeVendorClassIdentifier()
- if err != nil {
- return nil, err
- }
- d.UpdateOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID})
- return d, nil
+ ),
+ dhcpv4.WithOption(&dhcpv4.OptMaximumDHCPMessageSize{Size: MaxDHCPMessageSize}),
+ dhcpv4.WithOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}),
+ dhcpv4.WithOption(&OptVendorSpecificInformation{vendorOpts}),
+ )...)
}
// InformSelectForAck constructs an INFORM[SELECT] packet given an ACK to the
// previously-sent INFORM[LIST].
func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootImage) (*dhcpv4.DHCPv4, error) {
- d, err := dhcpv4.New()
- if err != nil {
- return nil, err
- }
-
if needsReplyPort(replyPort) && replyPort >= 1024 {
return nil, errors.New("replyPort must be a privileged port")
}
- d.OpCode = dhcpv4.OpcodeBootRequest
- d.HWType = ack.HWType
- d.ClientHWAddr = ack.ClientHWAddr
- d.TransactionID = ack.TransactionID
- if ack.IsBroadcast() {
- d.SetBroadcast()
- } else {
- d.SetUnicast()
- }
// Data for OptionSelectedBootImageID
vendorOpts := []dhcpv4.Option{
@@ -161,6 +141,11 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI
&OptSelectedBootImageID{selectedImage.ID},
}
+ // Validate replyPort if requested.
+ if needsReplyPort(replyPort) {
+ vendorOpts = append(vendorOpts, &OptReplyPort{replyPort})
+ }
+
// Find server IP address
var serverIP net.IP
if opt := ack.GetOneOption(dhcpv4.OptionServerIdentifier); opt != nil {
@@ -171,28 +156,23 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI
}
vendorOpts = append(vendorOpts, &OptServerIdentifier{serverIP})
- // Validate replyPort if requested.
- if needsReplyPort(replyPort) {
- vendorOpts = append(vendorOpts, &OptReplyPort{replyPort})
- }
-
vendorClassID, err := MakeVendorClassIdentifier()
if err != nil {
return nil, err
}
- d.UpdateOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID})
- d.UpdateOption(&dhcpv4.OptParameterRequestList{
- RequestedOpts: []dhcpv4.OptionCode{
+
+ return dhcpv4.New(dhcpv4.WithReply(&ack),
+ dhcpv4.WithOption(&dhcpv4.OptClassIdentifier{Identifier: vendorClassID}),
+ dhcpv4.WithRequestedOptions(
dhcpv4.OptionSubnetMask,
dhcpv4.OptionRouter,
dhcpv4.OptionBootfileName,
dhcpv4.OptionVendorSpecificInformation,
dhcpv4.OptionClassIdentifier,
- },
- })
- d.UpdateOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeInform})
- d.UpdateOption(&OptVendorSpecificInformation{vendorOpts})
- return d, nil
+ ),
+ dhcpv4.WithMessageType(dhcpv4.MessageTypeInform),
+ dhcpv4.WithOption(&OptVendorSpecificInformation{vendorOpts}),
+ )
}
// NewReplyForInformList constructs an ACK for the INFORM[LIST] packet `inform`
diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go
index 2caa6e5..638a408 100644
--- a/dhcpv4/bsdp/bsdp_test.go
+++ b/dhcpv4/bsdp/bsdp_test.go
@@ -104,6 +104,7 @@ func TestNewInformList_ReplyPort(t *testing.T) {
func newAck(hwAddr net.HardwareAddr, transactionID [4]byte) *dhcpv4.DHCPv4 {
ack, _ := dhcpv4.New()
+ ack.OpCode = dhcpv4.OpcodeBootReply
ack.TransactionID = transactionID
ack.HWType = iana.HWTypeEthernet
ack.ClientHWAddr = hwAddr
diff --git a/dhcpv4/client.go b/dhcpv4/client.go
index d2d18db..ad80b34 100644
--- a/dhcpv4/client.go
+++ b/dhcpv4/client.go
@@ -220,13 +220,10 @@ func (c *Client) Exchange(ifname string, modifiers ...Modifier) ([]*DHCPv4, erro
}()
// Discover
- discover, err := NewDiscoveryForInterface(ifname)
+ discover, err := NewDiscoveryForInterface(ifname, modifiers...)
if err != nil {
return conversation, err
}
- for _, mod := range modifiers {
- discover = mod(discover)
- }
conversation = append(conversation, discover)
// Offer
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index 429d404..93a6531 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -50,7 +50,7 @@ type DHCPv4 struct {
// Modifier defines the signature for functions that can modify DHCPv4
// structures. This is used to simplify packet manipulation
-type Modifier func(d *DHCPv4) *DHCPv4
+type Modifier func(d *DHCPv4)
// IPv4AddrsForInterface obtains the currently-configured, non-loopback IPv4
// addresses for iface.
@@ -106,7 +106,7 @@ func GenerateTransactionID() (TransactionID, error) {
// won't be a valid DHCPv4 message so you will need to adjust its fields.
// See also NewDiscovery, NewOffer, NewRequest, NewAcknowledge, NewInform and
// NewRelease .
-func New() (*DHCPv4, error) {
+func New(modifiers ...Modifier) (*DHCPv4, error) {
xid, err := GenerateTransactionID()
if err != nil {
return nil, err
@@ -124,42 +124,37 @@ func New() (*DHCPv4, error) {
GatewayIPAddr: net.IPv4zero,
Options: make([]Option, 0, 10),
}
+ for _, mod := range modifiers {
+ mod(&d)
+ }
return &d, nil
}
// NewDiscoveryForInterface builds a new DHCPv4 Discovery message, with a default
// Ethernet HW type and the hardware address obtained from the specified
// interface.
-func NewDiscoveryForInterface(ifname string) (*DHCPv4, error) {
+func NewDiscoveryForInterface(ifname string, modifiers ...Modifier) (*DHCPv4, error) {
iface, err := net.InterfaceByName(ifname)
if err != nil {
return nil, err
}
- return NewDiscovery(iface.HardwareAddr)
+ return NewDiscovery(iface.HardwareAddr, modifiers...)
}
// NewDiscovery builds a new DHCPv4 Discovery message, with a default Ethernet
// HW type and specified hardware address.
-func NewDiscovery(hwaddr net.HardwareAddr) (*DHCPv4, error) {
- d, err := New()
- if err != nil {
- return nil, err
- }
- // get hw addr
- d.OpCode = OpcodeBootRequest
- d.HWType = iana.HWTypeEthernet
- d.ClientHWAddr = hwaddr
- d.SetBroadcast()
- d.UpdateOption(&OptMessageType{MessageType: MessageTypeDiscover})
- d.UpdateOption(&OptParameterRequestList{
- RequestedOpts: []OptionCode{
+func NewDiscovery(hwaddr net.HardwareAddr, modifiers ...Modifier) (*DHCPv4, error) {
+ return New(PrependModifiers(modifiers,
+ WithBroadcast(true),
+ WithHwAddr(hwaddr),
+ WithRequestedOptions(
OptionSubnetMask,
OptionRouter,
OptionDomainName,
OptionDomainNameServer,
- },
- })
- return d, nil
+ ),
+ WithMessageType(MessageTypeDiscover),
+ )...)
}
// NewInformForInterface builds a new DHCPv4 Informational message with default
@@ -190,74 +185,46 @@ func NewInformForInterface(ifname string, needsBroadcast bool) (*DHCPv4, error)
return pkt, nil
}
-// NewInform builds a new DHCPv4 Informational message with default Ethernet HW
-// type and specified hardware address. It does NOT put a DHCP End option at the
-// end.
-func NewInform(hwaddr net.HardwareAddr, localIP net.IP) (*DHCPv4, error) {
- d, err := New()
- if err != nil {
- return nil, err
- }
+// PrependModifiers prepends other to m.
+func PrependModifiers(m []Modifier, other ...Modifier) []Modifier {
+ return append(other, m...)
+}
- d.OpCode = OpcodeBootRequest
- d.HWType = iana.HWTypeEthernet
- d.ClientHWAddr = hwaddr
- d.ClientIPAddr = localIP
- d.UpdateOption(&OptMessageType{MessageType: MessageTypeInform})
- return d, nil
+// NewInform builds a new DHCPv4 Informational message with the specified
+// hardware address.
+func NewInform(hwaddr net.HardwareAddr, localIP net.IP, modifiers ...Modifier) (*DHCPv4, error) {
+ return New(PrependModifiers(
+ modifiers,
+ WithHwAddr(hwaddr),
+ WithMessageType(MessageTypeInform),
+ WithClientIP(localIP),
+ )...)
}
// NewRequestFromOffer builds a DHCPv4 request from an offer.
func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) {
- d, err := New()
- if err != nil {
- return nil, err
- }
- d.OpCode = OpcodeBootRequest
- d.HWType = offer.HWType
- d.ClientHWAddr = offer.ClientHWAddr
- d.TransactionID = offer.TransactionID
- if offer.IsBroadcast() {
- d.SetBroadcast()
- } else {
- d.SetUnicast()
- }
// find server IP address
- var serverIP []byte
- for _, opt := range offer.Options {
- if opt.Code() == OptionServerIdentifier {
- serverIP = opt.(*OptServerIdentifier).ServerID
- }
+ var serverIP net.IP
+ serverID := offer.GetOneOption(OptionServerIdentifier)
+ if serverID != nil {
+ serverIP = serverID.(*OptServerIdentifier).ServerID
}
if serverIP == nil {
return nil, errors.New("Missing Server IP Address in DHCP Offer")
}
- d.ServerIPAddr = serverIP
- d.UpdateOption(&OptMessageType{MessageType: MessageTypeRequest})
- d.UpdateOption(&OptRequestedIPAddress{RequestedAddr: offer.YourIPAddr})
- d.UpdateOption(&OptServerIdentifier{ServerID: serverIP})
- for _, mod := range modifiers {
- d = mod(d)
- }
- return d, nil
+
+ return New(PrependModifiers(modifiers,
+ WithReply(offer),
+ WithMessageType(MessageTypeRequest),
+ WithServerIP(serverIP),
+ WithOption(&OptRequestedIPAddress{RequestedAddr: offer.YourIPAddr}),
+ WithOption(&OptServerIdentifier{ServerID: serverIP}),
+ )...)
}
// NewReplyFromRequest builds a DHCPv4 reply from a request.
func NewReplyFromRequest(request *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) {
- reply, err := New()
- if err != nil {
- return nil, err
- }
- reply.OpCode = OpcodeBootReply
- reply.HWType = request.HWType
- reply.ClientHWAddr = request.ClientHWAddr
- reply.TransactionID = request.TransactionID
- reply.Flags = request.Flags
- reply.GatewayIPAddr = request.GatewayIPAddr
- for _, mod := range modifiers {
- reply = mod(reply)
- }
- return reply, nil
+ return New(PrependModifiers(modifiers, WithReply(request))...)
}
// FromBytes encodes the DHCPv4 packet into a sequence of bytes, and returns an
diff --git a/dhcpv4/modifiers.go b/dhcpv4/modifiers.go
index cd80bc9..0759491 100644
--- a/dhcpv4/modifiers.go
+++ b/dhcpv4/modifiers.go
@@ -3,42 +3,84 @@ package dhcpv4
import (
"net"
+ "github.com/insomniacslk/dhcp/iana"
"github.com/insomniacslk/dhcp/rfc1035label"
)
// WithTransactionID sets the Transaction ID for the DHCPv4 packet
func WithTransactionID(xid TransactionID) Modifier {
- return func(d *DHCPv4) *DHCPv4 {
+ return func(d *DHCPv4) {
d.TransactionID = xid
- return d
+ }
+}
+
+// WithClientIP sets the Client IP for a DHCPv4 packet.
+func WithClientIP(ip net.IP) Modifier {
+ return func(d *DHCPv4) {
+ d.ClientIPAddr = ip
+ }
+}
+
+// WithYourIP sets the Your IP for a DHCPv4 packet.
+func WithYourIP(ip net.IP) Modifier {
+ return func(d *DHCPv4) {
+ d.YourIPAddr = ip
+ }
+}
+
+// WithServerIP sets the Server IP for a DHCPv4 packet.
+func WithServerIP(ip net.IP) Modifier {
+ return func(d *DHCPv4) {
+ d.ServerIPAddr = ip
+ }
+}
+
+// WithReply fills in opcode, hwtype, xid, clienthwaddr, flags, and gateway ip
+// addr from the given packet.
+func WithReply(request *DHCPv4) Modifier {
+ return func(d *DHCPv4) {
+ if request.OpCode == OpcodeBootRequest {
+ d.OpCode = OpcodeBootReply
+ } else {
+ d.OpCode = OpcodeBootRequest
+ }
+ d.HWType = request.HWType
+ d.TransactionID = request.TransactionID
+ d.ClientHWAddr = request.ClientHWAddr
+ d.Flags = request.Flags
+ d.GatewayIPAddr = request.GatewayIPAddr
+ }
+}
+
+// WithHWType sets the Hardware Type for a DHCPv4 packet.
+func WithHWType(hwt iana.HWType) Modifier {
+ return func(d *DHCPv4) {
+ d.HWType = hwt
}
}
// WithBroadcast sets the packet to be broadcast or unicast
func WithBroadcast(broadcast bool) Modifier {
- return func(d *DHCPv4) *DHCPv4 {
+ return func(d *DHCPv4) {
if broadcast {
d.SetBroadcast()
} else {
d.SetUnicast()
}
- return d
}
}
// WithHwAddr sets the hardware address for a packet
func WithHwAddr(hwaddr net.HardwareAddr) Modifier {
- return func(d *DHCPv4) *DHCPv4 {
+ return func(d *DHCPv4) {
d.ClientHWAddr = hwaddr
- return d
}
}
// WithOption appends a DHCPv4 option provided by the user
func WithOption(opt Option) Modifier {
- return func(d *DHCPv4) *DHCPv4 {
+ return func(d *DHCPv4) {
d.UpdateOption(opt)
- return d
}
}
@@ -54,13 +96,18 @@ func WithUserClass(uc []byte, rfc bool) Modifier {
}
// WithNetboot adds bootfile URL and bootfile param options to a DHCPv4 packet.
-func WithNetboot(d *DHCPv4) *DHCPv4 {
- return WithRequestedOptions(OptionTFTPServerName, OptionBootfileName)(d)
+func WithNetboot(d *DHCPv4) {
+ WithRequestedOptions(OptionTFTPServerName, OptionBootfileName)(d)
+}
+
+// WithMessageType adds the DHCPv4 message type m to a packet.
+func WithMessageType(m MessageType) Modifier {
+ return WithOption(&OptMessageType{m})
}
// WithRequestedOptions adds requested options to the packet.
func WithRequestedOptions(optionCodes ...OptionCode) Modifier {
- return func(d *DHCPv4) *DHCPv4 {
+ return func(d *DHCPv4) {
params := d.GetOneOption(OptionParameterRequestList)
if params == nil {
d.UpdateOption(&OptParameterRequestList{OptionCodeList(optionCodes)})
@@ -68,18 +115,16 @@ func WithRequestedOptions(optionCodes ...OptionCode) Modifier {
opts := params.(*OptParameterRequestList)
opts.RequestedOpts.Add(optionCodes...)
}
- return d
}
}
// WithRelay adds parameters required for DHCPv4 to be relayed by the relay
// server with given ip
func WithRelay(ip net.IP) Modifier {
- return func(d *DHCPv4) *DHCPv4 {
+ return func(d *DHCPv4) {
d.SetUnicast()
d.GatewayIPAddr = ip
d.HopCount += 1
- return d
}
}
diff --git a/dhcpv4/modifiers_test.go b/dhcpv4/modifiers_test.go
index d9bb3c7..2cac2a0 100644
--- a/dhcpv4/modifiers_test.go
+++ b/dhcpv4/modifiers_test.go
@@ -10,33 +10,34 @@ import (
func TestTransactionIDModifier(t *testing.T) {
d, err := New()
require.NoError(t, err)
- d = WithTransactionID(TransactionID{0xdd, 0xcc, 0xbb, 0xaa})(d)
+ WithTransactionID(TransactionID{0xdd, 0xcc, 0xbb, 0xaa})(d)
require.Equal(t, TransactionID{0xdd, 0xcc, 0xbb, 0xaa}, d.TransactionID)
}
func TestBroadcastModifier(t *testing.T) {
d, err := New()
require.NoError(t, err)
+
// set and test broadcast
- d = WithBroadcast(true)(d)
+ WithBroadcast(true)(d)
require.Equal(t, true, d.IsBroadcast())
+
// set and test unicast
- d = WithBroadcast(false)(d)
+ WithBroadcast(false)(d)
require.Equal(t, true, d.IsUnicast())
}
func TestHwAddrModifier(t *testing.T) {
- d, err := New()
+ hwaddr := net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
+ d, err := New(WithHwAddr(hwaddr))
require.NoError(t, err)
- hwaddr := net.HardwareAddr{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
- d = WithHwAddr(hwaddr)(d)
require.Equal(t, hwaddr, d.ClientHWAddr)
}
func TestWithOptionModifier(t *testing.T) {
- d, err := New()
+ d, err := New(WithOption(&OptDomainName{DomainName: "slackware.it"}))
require.NoError(t, err)
- d = WithOption(&OptDomainName{DomainName: "slackware.it"})(d)
+
opt := d.GetOneOption(OptionDomainName)
require.NotNil(t, opt)
dnOpt := opt.(*OptDomainName)
@@ -44,10 +45,9 @@ func TestWithOptionModifier(t *testing.T) {
}
func TestUserClassModifier(t *testing.T) {
- d, err := New()
+ d, err := New(WithUserClass([]byte("linuxboot"), false))
require.NoError(t, err)
- userClass := WithUserClass([]byte("linuxboot"), false)
- d = userClass(d)
+
expected := []byte{
'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't',
}
@@ -56,9 +56,9 @@ func TestUserClassModifier(t *testing.T) {
}
func TestUserClassModifierRFC(t *testing.T) {
- d, _ := New()
- userClass := WithUserClass([]byte("linuxboot"), true)
- d = userClass(d)
+ d, err := New(WithUserClass([]byte("linuxboot"), true))
+ require.NoError(t, err)
+
expected := []byte{
9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't',
}
@@ -67,53 +67,52 @@ func TestUserClassModifierRFC(t *testing.T) {
}
func TestWithNetboot(t *testing.T) {
- d, _ := New()
- d = WithNetboot(d)
+ d, err := New(WithNetboot)
+ require.NoError(t, err)
+
require.Equal(t, "Parameter Request List -> TFTP Server Name, Bootfile Name", d.Options[0].String())
}
func TestWithNetbootExistingTFTP(t *testing.T) {
- d, _ := New()
- OptParams := &OptParameterRequestList{
+ d, err := New()
+ require.NoError(t, err)
+ d.UpdateOption(&OptParameterRequestList{
RequestedOpts: []OptionCode{OptionTFTPServerName},
- }
- d.UpdateOption(OptParams)
- d = WithNetboot(d)
+ })
+ WithNetboot(d)
require.Equal(t, "Parameter Request List -> TFTP Server Name, Bootfile Name", d.Options[0].String())
}
func TestWithNetbootExistingBootfileName(t *testing.T) {
d, _ := New()
- OptParams := &OptParameterRequestList{
+ d.UpdateOption(&OptParameterRequestList{
RequestedOpts: []OptionCode{OptionBootfileName},
- }
- d.UpdateOption(OptParams)
- d = WithNetboot(d)
+ })
+ WithNetboot(d)
require.Equal(t, "Parameter Request List -> Bootfile Name, TFTP Server Name", d.Options[0].String())
}
func TestWithNetbootExistingBoth(t *testing.T) {
d, _ := New()
- OptParams := &OptParameterRequestList{
+ d.UpdateOption(&OptParameterRequestList{
RequestedOpts: []OptionCode{OptionBootfileName, OptionTFTPServerName},
- }
- d.UpdateOption(OptParams)
- d = WithNetboot(d)
+ })
+ WithNetboot(d)
require.Equal(t, "Parameter Request List -> Bootfile Name, TFTP Server Name", d.Options[0].String())
}
func TestWithRequestedOptions(t *testing.T) {
// Check if OptionParameterRequestList is created when not present
- d, err := New()
+ d, err := New(WithRequestedOptions(OptionFQDN))
require.NoError(t, err)
- d = WithRequestedOptions(OptionFQDN)(d)
require.NotNil(t, d)
o := d.GetOneOption(OptionParameterRequestList)
require.NotNil(t, o)
opts := o.(*OptParameterRequestList)
require.ElementsMatch(t, opts.RequestedOpts, []OptionCode{OptionFQDN})
+
// Check if already set options are preserved
- d = WithRequestedOptions(OptionHostName)(d)
+ WithRequestedOptions(OptionHostName)(d)
require.NotNil(t, d)
o = d.GetOneOption(OptionParameterRequestList)
require.NotNil(t, o)
@@ -122,20 +121,19 @@ func TestWithRequestedOptions(t *testing.T) {
}
func TestWithRelay(t *testing.T) {
- d, err := New()
+ ip := net.IP{10, 0, 0, 1}
+ d, err := New(WithRelay(ip))
require.NoError(t, err)
- ip := net.ParseIP("10.0.0.1")
- require.NotNil(t, ip)
- d = WithRelay(ip)(d)
- require.NotNil(t, d)
+
require.True(t, d.IsUnicast(), "expected unicast")
require.Equal(t, ip, d.GatewayIPAddr)
require.Equal(t, uint8(1), d.HopCount)
}
func TestWithNetmask(t *testing.T) {
- d := &DHCPv4{}
- d = WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d, err := New(WithNetmask(net.IPv4Mask(255, 255, 255, 0)))
+ require.NoError(t, err)
+
require.Equal(t, 1, len(d.Options))
require.Equal(t, OptionSubnetMask, d.Options[0].Code())
osm := d.Options[0].(*OptSubnetMask)
@@ -143,8 +141,9 @@ func TestWithNetmask(t *testing.T) {
}
func TestWithLeaseTime(t *testing.T) {
- d := &DHCPv4{}
- d = WithLeaseTime(uint32(3600))(d)
+ d, err := New(WithLeaseTime(uint32(3600)))
+ require.NoError(t, err)
+
require.Equal(t, 1, len(d.Options))
require.Equal(t, OptionIPAddressLeaseTime, d.Options[0].Code())
olt := d.Options[0].(*OptIPAddressLeaseTime)
@@ -152,8 +151,9 @@ func TestWithLeaseTime(t *testing.T) {
}
func TestWithDNS(t *testing.T) {
- d := &DHCPv4{}
- d = WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"))(d)
+ d, err := New(WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2")))
+ require.NoError(t, err)
+
require.Equal(t, 1, len(d.Options))
require.Equal(t, OptionDomainNameServer, d.Options[0].Code())
olt := d.Options[0].(*OptDomainNameServer)
@@ -164,8 +164,9 @@ func TestWithDNS(t *testing.T) {
}
func TestWithDomainSearchList(t *testing.T) {
- d := &DHCPv4{}
- d = WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
+ d, err := New(WithDomainSearchList("slackware.it", "dhcp.slackware.it"))
+ require.NoError(t, err)
+
require.Equal(t, 1, len(d.Options))
osl := d.Options[0].(*OptDomainSearch)
require.Equal(t, OptionDNSDomainSearchList, osl.Code())
@@ -176,9 +177,10 @@ func TestWithDomainSearchList(t *testing.T) {
}
func TestWithRouter(t *testing.T) {
- d := &DHCPv4{}
rtr := net.ParseIP("10.0.0.254")
- d = WithRouter(rtr)(d)
+ d, err := New(WithRouter(rtr))
+ require.NoError(t, err)
+
require.Equal(t, 1, len(d.Options))
ortr := d.Options[0].(*OptRouter)
require.Equal(t, OptionRouter, ortr.Code())
diff --git a/netboot/netconf_test.go b/netboot/netconf_test.go
index 91d4482..00b39b8 100644
--- a/netboot/netconf_test.go
+++ b/netboot/netconf_test.go
@@ -107,106 +107,114 @@ func TestGetNetConfFromPacketv6(t *testing.T) {
}
func TestGetNetConfFromPacketv4AddrZero(t *testing.T) {
- d := dhcpv4.DHCPv4{}
- d.YourIPAddr = net.IPv4zero
- _, err := GetNetConfFromPacketv4(&d)
+ d, _ := dhcpv4.New(dhcpv4.WithYourIP(net.IPv4zero))
+ _, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4NoMask(t *testing.T) {
- d := dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- _, err := GetNetConfFromPacketv4(&d)
+ d, _ := dhcpv4.New(dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")))
+ _, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4NullMask(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(0, 0, 0, 0))(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(0, 0, 0, 0)),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4NoLeaseTime(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4NoDNS(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
- d = dhcpv4.WithLeaseTime(uint32(0))(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithLeaseTime(uint32(0)),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4EmptyDNSList(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
- d = dhcpv4.WithLeaseTime(uint32(0))(d)
- d = dhcpv4.WithDNS()(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithLeaseTime(uint32(0)),
+ dhcpv4.WithDNS(),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4NoSearchList(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
- d = dhcpv4.WithLeaseTime(uint32(0))(d)
- d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithLeaseTime(uint32(0)),
+ dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4EmptySearchList(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
- d = dhcpv4.WithLeaseTime(uint32(0))(d)
- d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
- d = dhcpv4.WithDomainSearchList()(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithLeaseTime(uint32(0)),
+ dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")),
+ dhcpv4.WithDomainSearchList(),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4NoRouter(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
- d = dhcpv4.WithLeaseTime(uint32(0))(d)
- d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
- d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithLeaseTime(uint32(0)),
+ dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")),
+ dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4EmptyRouter(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
- d = dhcpv4.WithLeaseTime(uint32(0))(d)
- d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
- d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
- d = dhcpv4.WithRouter()(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithLeaseTime(uint32(0)),
+ dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")),
+ dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"),
+ dhcpv4.WithRouter(),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
_, err := GetNetConfFromPacketv4(d)
require.Error(t, err)
}
func TestGetNetConfFromPacketv4(t *testing.T) {
- d := &dhcpv4.DHCPv4{}
- d.YourIPAddr = net.ParseIP("10.0.0.1")
- d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
- d = dhcpv4.WithLeaseTime(uint32(5200))(d)
- d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
- d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
- d = dhcpv4.WithRouter(net.ParseIP("10.0.0.254"))(d)
+ d, _ := dhcpv4.New(
+ dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)),
+ dhcpv4.WithLeaseTime(uint32(5200)),
+ dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")),
+ dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"),
+ dhcpv4.WithRouter(net.ParseIP("10.0.0.254")),
+ dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")),
+ )
+
netconf, err := GetNetConfFromPacketv4(d)
require.NoError(t, err)
// check addresses