diff options
-rw-r--r-- | dhcpv6/dhcpv6message.go | 13 | ||||
-rw-r--r-- | dhcpv6/modifiers.go | 7 | ||||
-rw-r--r-- | dhcpv6/modifiers_test.go | 12 | ||||
-rw-r--r-- | dhcpv6/option_domainsearchlist.go | 23 | ||||
-rw-r--r-- | dhcpv6/option_domainsearchlist_test.go | 12 | ||||
-rw-r--r-- | dhcpv6/options.go | 2 | ||||
-rw-r--r-- | netboot/netconf.go | 8 |
7 files changed, 42 insertions, 35 deletions
diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go index 8ba7943..56f8627 100644 --- a/dhcpv6/dhcpv6message.go +++ b/dhcpv6/dhcpv6message.go @@ -7,6 +7,7 @@ import ( "time" "github.com/insomniacslk/dhcp/iana" + "github.com/insomniacslk/dhcp/rfc1035label" "github.com/u-root/u-root/pkg/rand" "github.com/u-root/u-root/pkg/uio" ) @@ -130,6 +131,18 @@ func (mo MessageOptions) DNS() []net.IP { return nil } +// DomainSearchList returns the Domain List option as defined by RFC 3646. +func (mo MessageOptions) DomainSearchList() *rfc1035label.Labels { + opt := mo.Options.GetOne(OptionDomainSearchList) + if opt == nil { + return nil + } + if dsl, ok := opt.(*optDomainSearchList); ok { + return dsl.DomainSearchList + } + return nil +} + // BootFileURL returns the Boot File URL option as defined by RFC 5970. func (mo MessageOptions) BootFileURL() string { opt := mo.Options.GetOne(OptionBootfileURL) diff --git a/dhcpv6/modifiers.go b/dhcpv6/modifiers.go index ee2056f..14bfe51 100644 --- a/dhcpv6/modifiers.go +++ b/dhcpv6/modifiers.go @@ -94,12 +94,11 @@ func WithDNS(dnses ...net.IP) Modifier { // WithDomainSearchList adds or updates an OptDomainSearchList func WithDomainSearchList(searchlist ...string) Modifier { return func(d DHCPv6) { - osl := OptDomainSearchList{ - DomainSearchList: &rfc1035label.Labels{ + d.UpdateOption(OptDomainSearchList( + &rfc1035label.Labels{ Labels: searchlist, }, - } - d.UpdateOption(&osl) + )) } } diff --git a/dhcpv6/modifiers_test.go b/dhcpv6/modifiers_test.go index e31cd29..b99d4a2 100644 --- a/dhcpv6/modifiers_test.go +++ b/dhcpv6/modifiers_test.go @@ -73,15 +73,11 @@ func TestWithDNS(t *testing.T) { func TestWithDomainSearchList(t *testing.T) { var d Message - WithDomainSearchList([]string{ - "slackware.it", - "dhcp.slackware.it", - }...)(&d) + WithDomainSearchList("slackware.it", "dhcp.slackware.it")(&d) require.Equal(t, 1, len(d.Options.Options)) - osl := d.Options.Options[0].(*OptDomainSearchList) - require.Equal(t, OptionDomainSearchList, osl.Code()) - require.NotNil(t, osl.DomainSearchList) - labels := osl.DomainSearchList.Labels + osl := d.Options.DomainSearchList() + require.NotNil(t, osl) + labels := osl.Labels require.Equal(t, 2, len(labels)) require.Equal(t, "slackware.it", labels[0]) require.Equal(t, "dhcp.slackware.it", labels[1]) diff --git a/dhcpv6/option_domainsearchlist.go b/dhcpv6/option_domainsearchlist.go index 6bbe4e3..e71a8d5 100644 --- a/dhcpv6/option_domainsearchlist.go +++ b/dhcpv6/option_domainsearchlist.go @@ -6,31 +6,32 @@ import ( "github.com/insomniacslk/dhcp/rfc1035label" ) -// OptDomainSearchList list implements a OptionDomainSearchList option -// -// This module defines the OptDomainSearchList structure. -// https://www.ietf.org/rfc/rfc3646.txt -type OptDomainSearchList struct { +// OptDomainSearchList returns a DomainSearchList option as defined by RFC 3646. +func OptDomainSearchList(labels *rfc1035label.Labels) Option { + return &optDomainSearchList{DomainSearchList: labels} +} + +type optDomainSearchList struct { DomainSearchList *rfc1035label.Labels } -func (op *OptDomainSearchList) Code() OptionCode { +func (op *optDomainSearchList) Code() OptionCode { return OptionDomainSearchList } // ToBytes marshals this option to bytes. -func (op *OptDomainSearchList) ToBytes() []byte { +func (op *optDomainSearchList) ToBytes() []byte { return op.DomainSearchList.ToBytes() } -func (op *OptDomainSearchList) String() string { - return fmt.Sprintf("OptDomainSearchList{searchlist=%v}", op.DomainSearchList.Labels) +func (op *optDomainSearchList) String() string { + return fmt.Sprintf("DomainSearchList: %s", op.DomainSearchList) } // ParseOptDomainSearchList builds an OptDomainSearchList structure from a sequence // of bytes. The input data does not include option code and length bytes. -func ParseOptDomainSearchList(data []byte) (*OptDomainSearchList, error) { - var opt OptDomainSearchList +func parseOptDomainSearchList(data []byte) (*optDomainSearchList, error) { + var opt optDomainSearchList var err error opt.DomainSearchList, err = rfc1035label.FromBytes(data) if err != nil { diff --git a/dhcpv6/option_domainsearchlist_test.go b/dhcpv6/option_domainsearchlist_test.go index b4d0195..433f710 100644 --- a/dhcpv6/option_domainsearchlist_test.go +++ b/dhcpv6/option_domainsearchlist_test.go @@ -12,13 +12,13 @@ func TestParseOptDomainSearchList(t *testing.T) { 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } - opt, err := ParseOptDomainSearchList(data) + opt, err := parseOptDomainSearchList(data) require.NoError(t, err) require.Equal(t, OptionDomainSearchList, opt.Code()) require.Equal(t, 2, len(opt.DomainSearchList.Labels)) require.Equal(t, "example.com", opt.DomainSearchList.Labels[0]) require.Equal(t, "subnet.example.org", opt.DomainSearchList.Labels[1]) - require.Contains(t, opt.String(), "searchlist=[example.com subnet.example.org]", "String() should contain the correct domain search output") + require.Contains(t, opt.String(), "example.com subnet.example.org", "String() should contain the correct domain search output") } func TestOptDomainSearchListToBytes(t *testing.T) { @@ -26,14 +26,14 @@ func TestOptDomainSearchListToBytes(t *testing.T) { 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } - opt := OptDomainSearchList{ - DomainSearchList: &rfc1035label.Labels{ + opt := OptDomainSearchList( + &rfc1035label.Labels{ Labels: []string{ "example.com", "subnet.example.org", }, }, - } + ) require.Equal(t, expected, opt.ToBytes()) } @@ -42,6 +42,6 @@ func TestParseOptDomainSearchListInvalidLength(t *testing.T) { 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', // truncated } - _, err := ParseOptDomainSearchList(data) + _, err := parseOptDomainSearchList(data) require.Error(t, err, "A truncated OptDomainSearchList should return an error") } diff --git a/dhcpv6/options.go b/dhcpv6/options.go index 23a0ba9..9b76f6d 100644 --- a/dhcpv6/options.go +++ b/dhcpv6/options.go @@ -68,7 +68,7 @@ func ParseOption(code OptionCode, optData []byte) (Option, error) { case OptionDNSRecursiveNameServer: opt, err = parseOptDNS(optData) case OptionDomainSearchList: - opt, err = ParseOptDomainSearchList(optData) + opt, err = parseOptDomainSearchList(optData) case OptionIAPD: opt, err = ParseOptIAForPrefixDelegation(optData) case OptionIAPrefix: diff --git a/netboot/netconf.go b/netboot/netconf.go index 6f05e93..78ddff8 100644 --- a/netboot/netconf.go +++ b/netboot/netconf.go @@ -66,11 +66,9 @@ func GetNetConfFromPacketv6(d *dhcpv6.Message) (*NetConf, error) { } netconf.DNSServers = dns - opt := d.GetOneOption(dhcpv6.OptionDomainSearchList) - if opt != nil { - odomains := opt.(*dhcpv6.OptDomainSearchList) - // TODO should this be copied? - netconf.DNSSearchList = odomains.DomainSearchList.Labels + domains := d.Options.DomainSearchList() + if domains != nil { + netconf.DNSSearchList = domains.Labels } return &netconf, nil |