summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/dhcpv4.go14
-rw-r--r--dhcpv4/dhcpv4_test.go21
-rw-r--r--dhcpv4/modifiers.go54
-rw-r--r--dhcpv4/modifiers_test.go53
-rw-r--r--dhcpv4/option_domain_search.go3
-rw-r--r--dhcpv6/dhcpv6message.go2
-rw-r--r--netboot/netconf.go4
-rw-r--r--netboot/netconf_test.go137
8 files changed, 280 insertions, 8 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index 94b8351..2931565 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -612,6 +612,20 @@ func (d *DHCPv4) AddOption(option Option) {
}
}
+// UpdateOption updates the existing options with the passed option, adding it
+// at the end if not present already
+func (d *DHCPv4) UpdateOption(option Option) {
+ for idx, opt := range d.options {
+ if opt.Code() == option.Code() {
+ d.options[idx] = option
+ // don't look further
+ return
+ }
+ }
+ // if not found, add it
+ d.AddOption(option)
+}
+
// MessageType returns the message type, trying to extract it from the
// OptMessageType option. It returns nil if the message type cannot be extracted
func (d *DHCPv4) MessageType() *MessageType {
diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go
index 059ae0c..283e728 100644
--- a/dhcpv4/dhcpv4_test.go
+++ b/dhcpv4/dhcpv4_test.go
@@ -347,9 +347,7 @@ func TestGetOption(t *testing.T) {
func TestAddOption(t *testing.T) {
d, err := New()
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")}
bootFileOpt1 := &OptionGeneric{OptionCode: OptionBootfileName, Data: []byte("boot.img")}
@@ -363,6 +361,23 @@ func TestAddOption(t *testing.T) {
require.Equal(t, options[3].Code(), OptionEnd)
}
+func TestUpdateOption(t *testing.T) {
+ d, err := New()
+ require.NoError(t, err)
+ require.Equal(t, 1, len(d.options))
+ require.Equal(t, OptionEnd, d.options[0].Code())
+ // test that it will add the option since it's missing
+ d.UpdateOption(&OptDomainName{DomainName: "slackware.it"})
+ require.Equal(t, 2, len(d.options))
+ require.Equal(t, OptionDomainName, d.options[0].Code())
+ require.Equal(t, OptionEnd, d.options[1].Code())
+ // test that it won't add another option of the same type
+ d.UpdateOption(&OptDomainName{DomainName: "slackware.it"})
+ require.Equal(t, 2, len(d.options))
+ require.Equal(t, OptionDomainName, d.options[0].Code())
+ require.Equal(t, OptionEnd, d.options[1].Code())
+}
+
func TestStrippedOptions(t *testing.T) {
// Normal set of options that terminate with OptionEnd.
d, err := New()
diff --git a/dhcpv4/modifiers.go b/dhcpv4/modifiers.go
index 188d91f..033718e 100644
--- a/dhcpv4/modifiers.go
+++ b/dhcpv4/modifiers.go
@@ -2,6 +2,8 @@ package dhcpv4
import (
"net"
+
+ "github.com/insomniacslk/dhcp/rfc1035label"
)
// WithTransactionID sets the Transaction ID for the DHCPv4 packet
@@ -114,3 +116,55 @@ func WithRelay(ip net.IP) Modifier {
return d
}
}
+
+// WithNetmask adds or updates an OptSubnetMask
+func WithNetmask(mask net.IPMask) Modifier {
+ return func(d *DHCPv4) *DHCPv4 {
+ osm := OptSubnetMask{
+ SubnetMask: mask,
+ }
+ d.UpdateOption(&osm)
+ return d
+ }
+}
+
+// WithLeaseTime adds or updates an OptIPAddressLeaseTime
+func WithLeaseTime(leaseTime uint32) Modifier {
+ return func(d *DHCPv4) *DHCPv4 {
+ olt := OptIPAddressLeaseTime{
+ LeaseTime: leaseTime,
+ }
+ d.UpdateOption(&olt)
+ return d
+ }
+}
+
+// WithDNS adds or updates an OptionDomainNameServer
+func WithDNS(dnses ...net.IP) Modifier {
+ return func(d *DHCPv4) *DHCPv4 {
+ odns := OptDomainNameServer{NameServers: dnses}
+ d.UpdateOption(&odns)
+ return d
+ }
+}
+
+// WithDomainSearchList adds or updates an OptionDomainSearch
+func WithDomainSearchList(searchList ...string) Modifier {
+ return func(d *DHCPv4) *DHCPv4 {
+ labels := rfc1035label.Labels{
+ Labels: searchList,
+ }
+ odsl := OptDomainSearch{DomainSearch: &labels}
+ d.UpdateOption(&odsl)
+ return d
+ }
+}
+
+// WithRouter adds or updates an OptionRouter
+func WithRouter(routers ...net.IP) Modifier {
+ return func(d *DHCPv4) *DHCPv4 {
+ ortr := OptRouter{Routers: routers}
+ d.UpdateOption(&ortr)
+ return d
+ }
+}
diff --git a/dhcpv4/modifiers_test.go b/dhcpv4/modifiers_test.go
index ce4ff38..f50e40b 100644
--- a/dhcpv4/modifiers_test.go
+++ b/dhcpv4/modifiers_test.go
@@ -136,3 +136,56 @@ func TestWithRelay(t *testing.T) {
require.Equal(t, ip, d.GatewayIPAddr())
require.Equal(t, uint8(1), d.HopCount())
}
+
+func TestWithNetmask(t *testing.T) {
+ d := &DHCPv4{}
+ d = WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ require.Equal(t, 1, len(d.options))
+ require.Equal(t, OptionSubnetMask, d.options[0].Code())
+ osm := d.options[0].(*OptSubnetMask)
+ require.Equal(t, net.IPv4Mask(255, 255, 255, 0), osm.SubnetMask)
+}
+
+func TestWithLeaseTime(t *testing.T) {
+ d := &DHCPv4{}
+ d = WithLeaseTime(uint32(3600))(d)
+ require.Equal(t, 1, len(d.options))
+ require.Equal(t, OptionIPAddressLeaseTime, d.options[0].Code())
+ olt := d.options[0].(*OptIPAddressLeaseTime)
+ require.Equal(t, uint32(3600), olt.LeaseTime)
+}
+
+func TestWithDNS(t *testing.T) {
+ d := &DHCPv4{}
+ d = WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"))(d)
+ require.Equal(t, 1, len(d.options))
+ require.Equal(t, OptionDomainNameServer, d.options[0].Code())
+ olt := d.options[0].(*OptDomainNameServer)
+ require.Equal(t, 2, len(olt.NameServers))
+ require.Equal(t, net.ParseIP("10.0.0.1"), olt.NameServers[0])
+ require.Equal(t, net.ParseIP("10.0.0.2"), olt.NameServers[1])
+ require.NotEqual(t, net.ParseIP("10.0.0.1"), olt.NameServers[1])
+}
+
+func TestWithDomainSearchList(t *testing.T) {
+ d := &DHCPv4{}
+ d = WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
+ require.Equal(t, 1, len(d.options))
+ osl := d.options[0].(*OptDomainSearch)
+ require.Equal(t, OptionDNSDomainSearchList, osl.Code())
+ require.NotNil(t, osl.DomainSearch)
+ require.Equal(t, 2, len(osl.DomainSearch.Labels))
+ require.Equal(t, "slackware.it", osl.DomainSearch.Labels[0])
+ require.Equal(t, "dhcp.slackware.it", osl.DomainSearch.Labels[1])
+}
+
+func TestWithRouter(t *testing.T) {
+ d := &DHCPv4{}
+ rtr := net.ParseIP("10.0.0.254")
+ d = WithRouter(rtr)(d)
+ require.Equal(t, 1, len(d.options))
+ ortr := d.options[0].(*OptRouter)
+ require.Equal(t, OptionRouter, ortr.Code())
+ require.Equal(t, 1, len(ortr.Routers))
+ require.Equal(t, rtr, ortr.Routers[0])
+}
diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go
index c640c0f..9c24eea 100644
--- a/dhcpv4/option_domain_search.go
+++ b/dhcpv4/option_domain_search.go
@@ -9,6 +9,9 @@ import (
"github.com/insomniacslk/dhcp/rfc1035label"
)
+// FIXME rename OptDomainSearch to OptDomainSearchList, and DomainSearch to
+// SearchList, for consistency with the equivalent v6 option
+
// OptDomainSearch represents an option encapsulating a domain search list.
type OptDomainSearch struct {
DomainSearch *rfc1035label.Labels
diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go
index 82d44b4..545146c 100644
--- a/dhcpv6/dhcpv6message.go
+++ b/dhcpv6/dhcpv6message.go
@@ -280,6 +280,8 @@ func (d *DHCPv6Message) AddOption(option Option) {
d.options = append(d.options, option)
}
+// UpdateOption updates the existing options with the passed option, adding it
+// at the end if not present already
func (d *DHCPv6Message) UpdateOption(option Option) {
for idx, opt := range d.options {
if opt.Code() == option.Code() {
diff --git a/netboot/netconf.go b/netboot/netconf.go
index aef4efa..833bf52 100644
--- a/netboot/netconf.go
+++ b/netboot/netconf.go
@@ -106,10 +106,6 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) {
leaseTime = leaseTimeOption.(*dhcpv4.OptIPAddressLeaseTime).LeaseTime
}
- if int(leaseTime) < 0 {
- return nil, fmt.Errorf("lease time overflow, Original lease time: %d", leaseTime)
- }
-
netconf.Addresses = append(netconf.Addresses, AddrConf{
IPNet: net.IPNet{
IP: ipAddr,
diff --git a/netboot/netconf_test.go b/netboot/netconf_test.go
index 066753b..ab40648 100644
--- a/netboot/netconf_test.go
+++ b/netboot/netconf_test.go
@@ -5,6 +5,7 @@ import (
"net"
"testing"
+ "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/iana"
"github.com/stretchr/testify/require"
@@ -88,6 +89,140 @@ func TestGetNetConfFromPacketv6(t *testing.T) {
dhcpv6.WithDNS(net.ParseIP("fe80::1")),
dhcpv6.WithDomainSearchList("slackware.it"),
)
- _, err := GetNetConfFromPacketv6(adv)
+ netconf, err := GetNetConfFromPacketv6(adv)
+ require.NoError(t, err)
+ // check addresses
+ require.Equal(t, 1, len(netconf.Addresses))
+ require.Equal(t, net.ParseIP("::1"), netconf.Addresses[0].IPNet.IP)
+ require.Equal(t, 3600, netconf.Addresses[0].PreferredLifetime)
+ require.Equal(t, 5200, netconf.Addresses[0].ValidLifetime)
+ // check DNSes
+ require.Equal(t, 1, len(netconf.DNSServers))
+ require.Equal(t, net.ParseIP("fe80::1"), netconf.DNSServers[0])
+ // check DNS search list
+ require.Equal(t, 1, len(netconf.DNSSearchList))
+ require.Equal(t, "slackware.it", netconf.DNSSearchList[0])
+ // check routers
+ require.Equal(t, 0, len(netconf.Routers))
+}
+
+func TestGetNetConfFromPacketv4AddrZero(t *testing.T) {
+ d := dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.IPv4zero)
+ _, err := GetNetConfFromPacketv4(&d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4NoMask(t *testing.T) {
+ d := dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ _, err := GetNetConfFromPacketv4(&d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4NullMask(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(0, 0, 0, 0))(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4NoLeaseTime(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4NoDNS(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d = dhcpv4.WithLeaseTime(uint32(0))(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4EmptyDNSList(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d = dhcpv4.WithLeaseTime(uint32(0))(d)
+ d = dhcpv4.WithDNS()(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4NoSearchList(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d = dhcpv4.WithLeaseTime(uint32(0))(d)
+ d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4EmptySearchList(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d = dhcpv4.WithLeaseTime(uint32(0))(d)
+ d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
+ d = dhcpv4.WithDomainSearchList()(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4NoRouter(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d = dhcpv4.WithLeaseTime(uint32(0))(d)
+ d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
+ d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4EmptyRouter(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d = dhcpv4.WithLeaseTime(uint32(0))(d)
+ d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
+ d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
+ d = dhcpv4.WithRouter()(d)
+ _, err := GetNetConfFromPacketv4(d)
+ require.Error(t, err)
+}
+
+func TestGetNetConfFromPacketv4(t *testing.T) {
+ d := &dhcpv4.DHCPv4{}
+ d.SetYourIPAddr(net.ParseIP("10.0.0.1"))
+ d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d)
+ d = dhcpv4.WithLeaseTime(uint32(5200))(d)
+ d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d)
+ d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d)
+ d = dhcpv4.WithRouter(net.ParseIP("10.0.0.254"))(d)
+ netconf, err := GetNetConfFromPacketv4(d)
require.NoError(t, err)
+ // check addresses
+ require.Equal(t, 1, len(netconf.Addresses))
+ require.Equal(t, net.ParseIP("10.0.0.1"), netconf.Addresses[0].IPNet.IP)
+ require.Equal(t, 0, netconf.Addresses[0].PreferredLifetime)
+ require.Equal(t, 5200, netconf.Addresses[0].ValidLifetime)
+ // check DNSes
+ require.Equal(t, 2, len(netconf.DNSServers))
+ require.Equal(t, net.ParseIP("10.10.0.1"), netconf.DNSServers[0])
+ require.Equal(t, net.ParseIP("10.10.0.2"), netconf.DNSServers[1])
+ // check DNS search list
+ require.Equal(t, 2, len(netconf.DNSSearchList))
+ require.Equal(t, "slackware.it", netconf.DNSSearchList[0])
+ require.Equal(t, "dhcp.slackware.it", netconf.DNSSearchList[1])
+ // check routers
+ require.Equal(t, 1, len(netconf.Routers))
+ require.Equal(t, net.ParseIP("10.0.0.254"), netconf.Routers[0])
}