summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/client.go9
-rw-r--r--dhcpv4/dhcpv4.go10
-rw-r--r--dhcpv4/dhcpv4_test.go25
3 files changed, 34 insertions, 10 deletions
diff --git a/dhcpv4/client.go b/dhcpv4/client.go
index ce6f242..db7e71a 100644
--- a/dhcpv4/client.go
+++ b/dhcpv4/client.go
@@ -140,14 +140,13 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier
// Discover
if discover == nil {
- discover, err = NewDiscoveryForInterface(ifname, modifiers...)
+ discover, err = NewDiscoveryForInterface(ifname)
if err != nil {
return conversation, err
}
- } else {
- for _, mod := range modifiers {
- discover = mod(discover)
- }
+ }
+ for _, mod := range modifiers {
+ discover = mod(discover)
}
conversation[0] = *discover
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index a74f675..e027f28 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -124,7 +124,7 @@ func New() (*DHCPv4, error) {
// NewDiscoveryForInterface builds a new DHCPv4 Discovery message, with a default
// Ethernet HW type and the hardware address obtained from the specified
// interface.
-func NewDiscoveryForInterface(ifname string, modifiers ...Modifier) (*DHCPv4, error) {
+func NewDiscoveryForInterface(ifname string) (*DHCPv4, error) {
d, err := New()
if err != nil {
return nil, err
@@ -148,9 +148,6 @@ func NewDiscoveryForInterface(ifname string, modifiers ...Modifier) (*DHCPv4, er
OptionDomainNameServer,
},
})
- for _, mod := range modifiers {
- d = mod(d)
- }
return d, nil
}
@@ -228,7 +225,7 @@ func RequestFromOffer(offer DHCPv4, modifiers ...Modifier) (*DHCPv4, error) {
}
// NewReplyFromRequest builds a DHCPv4 reply from a request.
-func NewReplyFromRequest(request *DHCPv4) (*DHCPv4, error) {
+func NewReplyFromRequest(request *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) {
reply, err := New()
if err != nil {
return nil, err
@@ -241,6 +238,9 @@ func NewReplyFromRequest(request *DHCPv4) (*DHCPv4, error) {
reply.SetTransactionID(request.TransactionID())
reply.SetFlags(request.Flags())
reply.SetGatewayIPAddr(request.GatewayIPAddr())
+ for _, mod := range modifiers {
+ reply = mod(reply)
+ }
return reply, nil
}
diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go
index 55e082d..7e5f083 100644
--- a/dhcpv4/dhcpv4_test.go
+++ b/dhcpv4/dhcpv4_test.go
@@ -353,6 +353,19 @@ func TestDHCPv4RequestFromOffer(t *testing.T) {
require.Equal(t, MessageTypeRequest, *req.MessageType())
}
+func TestDHCPv4RequestFromOfferWithModifier(t *testing.T) {
+ offer, err := New()
+ require.NoError(t, err)
+ offer.AddOption(&OptMessageType{MessageType: MessageTypeOffer})
+ offer.AddOption(&OptServerIdentifier{ServerID: net.IPv4(192, 168, 0, 1)})
+ userClass := WithUserClass([]byte("linuxboot"))
+ req, err := RequestFromOffer(*offer, userClass)
+ require.NoError(t, err)
+ require.NotEqual(t, (*MessageType)(nil), *req.MessageType())
+ require.Equal(t, MessageTypeRequest, *req.MessageType())
+ require.Equal(t, "OptUserClass{userclass=[linuxboot]}", req.options[3].String())
+}
+
func TestNewReplyFromRequest(t *testing.T) {
discover, err := New()
require.NoError(t, err)
@@ -363,6 +376,18 @@ func TestNewReplyFromRequest(t *testing.T) {
require.Equal(t, discover.GatewayIPAddr(), reply.GatewayIPAddr())
}
+func TestNewReplyFromRequestWithModifier(t *testing.T) {
+ discover, err := New()
+ require.NoError(t, err)
+ discover.SetGatewayIPAddr(net.IPv4(192, 168, 0, 1))
+ userClass := WithUserClass([]byte("linuxboot"))
+ reply, err := NewReplyFromRequest(discover, userClass)
+ require.NoError(t, err)
+ require.Equal(t, discover.TransactionID(), reply.TransactionID())
+ require.Equal(t, discover.GatewayIPAddr(), reply.GatewayIPAddr())
+ require.Equal(t, "OptUserClass{userclass=[linuxboot]}", reply.options[0].String())
+}
+
func TestDHCPv4MessageTypeNil(t *testing.T) {
m, err := New()
require.NoError(t, err)