diff options
-rw-r--r-- | dhcpv4/dhcpv4.go | 14 | ||||
-rw-r--r-- | dhcpv4/dhcpv4_test.go | 21 | ||||
-rw-r--r-- | dhcpv4/modifiers.go | 54 | ||||
-rw-r--r-- | dhcpv4/modifiers_test.go | 53 | ||||
-rw-r--r-- | dhcpv4/option_domain_search.go | 3 | ||||
-rw-r--r-- | dhcpv6/dhcpv6message.go | 2 | ||||
-rw-r--r-- | netboot/netconf.go | 4 | ||||
-rw-r--r-- | netboot/netconf_test.go | 137 |
8 files changed, 280 insertions, 8 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index 94b8351..2931565 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -612,6 +612,20 @@ func (d *DHCPv4) AddOption(option Option) { } } +// UpdateOption updates the existing options with the passed option, adding it +// at the end if not present already +func (d *DHCPv4) UpdateOption(option Option) { + for idx, opt := range d.options { + if opt.Code() == option.Code() { + d.options[idx] = option + // don't look further + return + } + } + // if not found, add it + d.AddOption(option) +} + // MessageType returns the message type, trying to extract it from the // OptMessageType option. It returns nil if the message type cannot be extracted func (d *DHCPv4) MessageType() *MessageType { diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index 059ae0c..283e728 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -347,9 +347,7 @@ func TestGetOption(t *testing.T) { func TestAddOption(t *testing.T) { d, err := New() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")} bootFileOpt1 := &OptionGeneric{OptionCode: OptionBootfileName, Data: []byte("boot.img")} @@ -363,6 +361,23 @@ func TestAddOption(t *testing.T) { require.Equal(t, options[3].Code(), OptionEnd) } +func TestUpdateOption(t *testing.T) { + d, err := New() + require.NoError(t, err) + require.Equal(t, 1, len(d.options)) + require.Equal(t, OptionEnd, d.options[0].Code()) + // test that it will add the option since it's missing + d.UpdateOption(&OptDomainName{DomainName: "slackware.it"}) + require.Equal(t, 2, len(d.options)) + require.Equal(t, OptionDomainName, d.options[0].Code()) + require.Equal(t, OptionEnd, d.options[1].Code()) + // test that it won't add another option of the same type + d.UpdateOption(&OptDomainName{DomainName: "slackware.it"}) + require.Equal(t, 2, len(d.options)) + require.Equal(t, OptionDomainName, d.options[0].Code()) + require.Equal(t, OptionEnd, d.options[1].Code()) +} + func TestStrippedOptions(t *testing.T) { // Normal set of options that terminate with OptionEnd. d, err := New() diff --git a/dhcpv4/modifiers.go b/dhcpv4/modifiers.go index 188d91f..033718e 100644 --- a/dhcpv4/modifiers.go +++ b/dhcpv4/modifiers.go @@ -2,6 +2,8 @@ package dhcpv4 import ( "net" + + "github.com/insomniacslk/dhcp/rfc1035label" ) // WithTransactionID sets the Transaction ID for the DHCPv4 packet @@ -114,3 +116,55 @@ func WithRelay(ip net.IP) Modifier { return d } } + +// WithNetmask adds or updates an OptSubnetMask +func WithNetmask(mask net.IPMask) Modifier { + return func(d *DHCPv4) *DHCPv4 { + osm := OptSubnetMask{ + SubnetMask: mask, + } + d.UpdateOption(&osm) + return d + } +} + +// WithLeaseTime adds or updates an OptIPAddressLeaseTime +func WithLeaseTime(leaseTime uint32) Modifier { + return func(d *DHCPv4) *DHCPv4 { + olt := OptIPAddressLeaseTime{ + LeaseTime: leaseTime, + } + d.UpdateOption(&olt) + return d + } +} + +// WithDNS adds or updates an OptionDomainNameServer +func WithDNS(dnses ...net.IP) Modifier { + return func(d *DHCPv4) *DHCPv4 { + odns := OptDomainNameServer{NameServers: dnses} + d.UpdateOption(&odns) + return d + } +} + +// WithDomainSearchList adds or updates an OptionDomainSearch +func WithDomainSearchList(searchList ...string) Modifier { + return func(d *DHCPv4) *DHCPv4 { + labels := rfc1035label.Labels{ + Labels: searchList, + } + odsl := OptDomainSearch{DomainSearch: &labels} + d.UpdateOption(&odsl) + return d + } +} + +// WithRouter adds or updates an OptionRouter +func WithRouter(routers ...net.IP) Modifier { + return func(d *DHCPv4) *DHCPv4 { + ortr := OptRouter{Routers: routers} + d.UpdateOption(&ortr) + return d + } +} diff --git a/dhcpv4/modifiers_test.go b/dhcpv4/modifiers_test.go index ce4ff38..f50e40b 100644 --- a/dhcpv4/modifiers_test.go +++ b/dhcpv4/modifiers_test.go @@ -136,3 +136,56 @@ func TestWithRelay(t *testing.T) { require.Equal(t, ip, d.GatewayIPAddr()) require.Equal(t, uint8(1), d.HopCount()) } + +func TestWithNetmask(t *testing.T) { + d := &DHCPv4{} + d = WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + require.Equal(t, 1, len(d.options)) + require.Equal(t, OptionSubnetMask, d.options[0].Code()) + osm := d.options[0].(*OptSubnetMask) + require.Equal(t, net.IPv4Mask(255, 255, 255, 0), osm.SubnetMask) +} + +func TestWithLeaseTime(t *testing.T) { + d := &DHCPv4{} + d = WithLeaseTime(uint32(3600))(d) + require.Equal(t, 1, len(d.options)) + require.Equal(t, OptionIPAddressLeaseTime, d.options[0].Code()) + olt := d.options[0].(*OptIPAddressLeaseTime) + require.Equal(t, uint32(3600), olt.LeaseTime) +} + +func TestWithDNS(t *testing.T) { + d := &DHCPv4{} + d = WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"))(d) + require.Equal(t, 1, len(d.options)) + require.Equal(t, OptionDomainNameServer, d.options[0].Code()) + olt := d.options[0].(*OptDomainNameServer) + require.Equal(t, 2, len(olt.NameServers)) + require.Equal(t, net.ParseIP("10.0.0.1"), olt.NameServers[0]) + require.Equal(t, net.ParseIP("10.0.0.2"), olt.NameServers[1]) + require.NotEqual(t, net.ParseIP("10.0.0.1"), olt.NameServers[1]) +} + +func TestWithDomainSearchList(t *testing.T) { + d := &DHCPv4{} + d = WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) + require.Equal(t, 1, len(d.options)) + osl := d.options[0].(*OptDomainSearch) + require.Equal(t, OptionDNSDomainSearchList, osl.Code()) + require.NotNil(t, osl.DomainSearch) + require.Equal(t, 2, len(osl.DomainSearch.Labels)) + require.Equal(t, "slackware.it", osl.DomainSearch.Labels[0]) + require.Equal(t, "dhcp.slackware.it", osl.DomainSearch.Labels[1]) +} + +func TestWithRouter(t *testing.T) { + d := &DHCPv4{} + rtr := net.ParseIP("10.0.0.254") + d = WithRouter(rtr)(d) + require.Equal(t, 1, len(d.options)) + ortr := d.options[0].(*OptRouter) + require.Equal(t, OptionRouter, ortr.Code()) + require.Equal(t, 1, len(ortr.Routers)) + require.Equal(t, rtr, ortr.Routers[0]) +} diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go index c640c0f..9c24eea 100644 --- a/dhcpv4/option_domain_search.go +++ b/dhcpv4/option_domain_search.go @@ -9,6 +9,9 @@ import ( "github.com/insomniacslk/dhcp/rfc1035label" ) +// FIXME rename OptDomainSearch to OptDomainSearchList, and DomainSearch to +// SearchList, for consistency with the equivalent v6 option + // OptDomainSearch represents an option encapsulating a domain search list. type OptDomainSearch struct { DomainSearch *rfc1035label.Labels diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go index 82d44b4..545146c 100644 --- a/dhcpv6/dhcpv6message.go +++ b/dhcpv6/dhcpv6message.go @@ -280,6 +280,8 @@ func (d *DHCPv6Message) AddOption(option Option) { d.options = append(d.options, option) } +// UpdateOption updates the existing options with the passed option, adding it +// at the end if not present already func (d *DHCPv6Message) UpdateOption(option Option) { for idx, opt := range d.options { if opt.Code() == option.Code() { diff --git a/netboot/netconf.go b/netboot/netconf.go index aef4efa..833bf52 100644 --- a/netboot/netconf.go +++ b/netboot/netconf.go @@ -106,10 +106,6 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) { leaseTime = leaseTimeOption.(*dhcpv4.OptIPAddressLeaseTime).LeaseTime } - if int(leaseTime) < 0 { - return nil, fmt.Errorf("lease time overflow, Original lease time: %d", leaseTime) - } - netconf.Addresses = append(netconf.Addresses, AddrConf{ IPNet: net.IPNet{ IP: ipAddr, diff --git a/netboot/netconf_test.go b/netboot/netconf_test.go index 066753b..ab40648 100644 --- a/netboot/netconf_test.go +++ b/netboot/netconf_test.go @@ -5,6 +5,7 @@ import ( "net" "testing" + "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/iana" "github.com/stretchr/testify/require" @@ -88,6 +89,140 @@ func TestGetNetConfFromPacketv6(t *testing.T) { dhcpv6.WithDNS(net.ParseIP("fe80::1")), dhcpv6.WithDomainSearchList("slackware.it"), ) - _, err := GetNetConfFromPacketv6(adv) + netconf, err := GetNetConfFromPacketv6(adv) + require.NoError(t, err) + // check addresses + require.Equal(t, 1, len(netconf.Addresses)) + require.Equal(t, net.ParseIP("::1"), netconf.Addresses[0].IPNet.IP) + require.Equal(t, 3600, netconf.Addresses[0].PreferredLifetime) + require.Equal(t, 5200, netconf.Addresses[0].ValidLifetime) + // check DNSes + require.Equal(t, 1, len(netconf.DNSServers)) + require.Equal(t, net.ParseIP("fe80::1"), netconf.DNSServers[0]) + // check DNS search list + require.Equal(t, 1, len(netconf.DNSSearchList)) + require.Equal(t, "slackware.it", netconf.DNSSearchList[0]) + // check routers + require.Equal(t, 0, len(netconf.Routers)) +} + +func TestGetNetConfFromPacketv4AddrZero(t *testing.T) { + d := dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.IPv4zero) + _, err := GetNetConfFromPacketv4(&d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4NoMask(t *testing.T) { + d := dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + _, err := GetNetConfFromPacketv4(&d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4NullMask(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(0, 0, 0, 0))(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4NoLeaseTime(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4NoDNS(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d = dhcpv4.WithLeaseTime(uint32(0))(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4EmptyDNSList(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d = dhcpv4.WithLeaseTime(uint32(0))(d) + d = dhcpv4.WithDNS()(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4NoSearchList(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d = dhcpv4.WithLeaseTime(uint32(0))(d) + d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4EmptySearchList(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d = dhcpv4.WithLeaseTime(uint32(0))(d) + d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) + d = dhcpv4.WithDomainSearchList()(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4NoRouter(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d = dhcpv4.WithLeaseTime(uint32(0))(d) + d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) + d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4EmptyRouter(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d = dhcpv4.WithLeaseTime(uint32(0))(d) + d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) + d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) + d = dhcpv4.WithRouter()(d) + _, err := GetNetConfFromPacketv4(d) + require.Error(t, err) +} + +func TestGetNetConfFromPacketv4(t *testing.T) { + d := &dhcpv4.DHCPv4{} + d.SetYourIPAddr(net.ParseIP("10.0.0.1")) + d = dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) + d = dhcpv4.WithLeaseTime(uint32(5200))(d) + d = dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2"))(d) + d = dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) + d = dhcpv4.WithRouter(net.ParseIP("10.0.0.254"))(d) + netconf, err := GetNetConfFromPacketv4(d) require.NoError(t, err) + // check addresses + require.Equal(t, 1, len(netconf.Addresses)) + require.Equal(t, net.ParseIP("10.0.0.1"), netconf.Addresses[0].IPNet.IP) + require.Equal(t, 0, netconf.Addresses[0].PreferredLifetime) + require.Equal(t, 5200, netconf.Addresses[0].ValidLifetime) + // check DNSes + require.Equal(t, 2, len(netconf.DNSServers)) + require.Equal(t, net.ParseIP("10.10.0.1"), netconf.DNSServers[0]) + require.Equal(t, net.ParseIP("10.10.0.2"), netconf.DNSServers[1]) + // check DNS search list + require.Equal(t, 2, len(netconf.DNSSearchList)) + require.Equal(t, "slackware.it", netconf.DNSSearchList[0]) + require.Equal(t, "dhcp.slackware.it", netconf.DNSSearchList[1]) + // check routers + require.Equal(t, 1, len(netconf.Routers)) + require.Equal(t, net.ParseIP("10.0.0.254"), netconf.Routers[0]) } |