diff options
-rw-r--r-- | dhcpv6/dhcpv6_test.go | 12 | ||||
-rw-r--r-- | dhcpv6/dhcpv6message.go | 35 | ||||
-rw-r--r-- | dhcpv6/dhcpv6message_test.go | 8 | ||||
-rw-r--r-- | dhcpv6/modifiers.go | 33 | ||||
-rw-r--r-- | dhcpv6/modifiers_test.go | 12 | ||||
-rw-r--r-- | dhcpv6/option_requestedoption.go | 81 | ||||
-rw-r--r-- | dhcpv6/option_requestedoption_test.go | 15 | ||||
-rw-r--r-- | dhcpv6/options.go | 4 |
8 files changed, 94 insertions, 106 deletions
diff --git a/dhcpv6/dhcpv6_test.go b/dhcpv6/dhcpv6_test.go index d975308..1ecaa3c 100644 --- a/dhcpv6/dhcpv6_test.go +++ b/dhcpv6/dhcpv6_test.go @@ -242,14 +242,10 @@ func TestNewMessageTypeSolicit(t *testing.T) { require.Equal(t, cduid, &duid) // Check ORO - oroOption := s.GetOneOption(OptionORO) - require.NotNil(t, oroOption) - oro, ok := oroOption.(*OptRequestedOption) - require.True(t, ok) - opts := oro.RequestedOptions() - require.Contains(t, opts, OptionDNSRecursiveNameServer) - require.Contains(t, opts, OptionDomainSearchList) - require.Equal(t, len(opts), 2) + oro := s.Options.RequestedOptions() + require.Contains(t, oro, OptionDNSRecursiveNameServer) + require.Contains(t, oro, OptionDomainSearchList) + require.Equal(t, len(oro), 2) // Check IA_NA iaid := [4]byte{hwAddr[2], hwAddr[3], hwAddr[4], hwAddr[5]} diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go index 1bc8670..28853f1 100644 --- a/dhcpv6/dhcpv6message.go +++ b/dhcpv6/dhcpv6message.go @@ -79,6 +79,19 @@ func (mo MessageOptions) Status() *OptStatusCode { return sc } +// RequestedOptions returns the Options Requested Option. +func (mo MessageOptions) RequestedOptions() OptionCodes { + opt := mo.Options.GetOne(OptionORO) + if opt == nil { + return nil + } + oro, ok := opt.(*optRequestedOption) + if !ok { + return nil + } + return oro.OptionCodes +} + // Message represents a DHCPv6 Message as defined by RFC 3315 Section 6. type Message struct { MessageType MessageType @@ -123,12 +136,10 @@ func NewSolicit(hwaddr net.HardwareAddr, modifiers ...Modifier) (*Message, error } m.MessageType = MessageTypeSolicit m.AddOption(OptClientID(duid)) - oro := new(OptRequestedOption) - oro.SetRequestedOptions([]OptionCode{ + m.AddOption(OptRequestedOption( OptionDNSRecursiveNameServer, OptionDomainSearchList, - }) - m.AddOption(oro) + )) m.AddOption(&OptElapsedTime{}) if len(hwaddr) < 4 { return nil, errors.New("short hardware addrss: less than 4 bytes") @@ -205,13 +216,10 @@ func NewRequestFromAdvertise(adv *Message, modifiers ...Modifier) (*Message, err return nil, fmt.Errorf("IA_NA cannot be nil in ADVERTISE when building REQUEST") } req.AddOption(iana) - // add OptRequestedOption - oro := OptRequestedOption{} - oro.SetRequestedOptions([]OptionCode{ + req.AddOption(OptRequestedOption( OptionDNSRecursiveNameServer, OptionDomainSearchList, - }) - req.AddOption(&oro) + )) // add OPTION_VENDOR_CLASS, only if present in the original request // TODO implement OptionVendorClass vClass := adv.GetOneOption(OptionVendorClass) @@ -303,14 +311,7 @@ func (m *Message) IsNetboot() bool { // IsOptionRequested takes an OptionCode and returns true if that option is // within the requested options of the DHCPv6 message. func (m *Message) IsOptionRequested(requested OptionCode) bool { - for _, optoro := range m.GetOption(OptionORO) { - for _, o := range optoro.(*OptRequestedOption).RequestedOptions() { - if o == requested { - return true - } - } - } - return false + return m.Options.RequestedOptions().Contains(requested) } // String returns a short human-readable string for this message. diff --git a/dhcpv6/dhcpv6message_test.go b/dhcpv6/dhcpv6message_test.go index d914e74..d5ae227 100644 --- a/dhcpv6/dhcpv6message_test.go +++ b/dhcpv6/dhcpv6message_test.go @@ -11,9 +11,7 @@ func TestIsNetboot(t *testing.T) { require.False(t, msg1.IsNetboot()) msg2 := Message{} - optro := OptRequestedOption{} - optro.AddRequestedOption(OptionBootfileURL) - msg2.AddOption(&optro) + msg2.AddOption(OptRequestedOption(OptionBootfileURL)) require.True(t, msg2.IsNetboot()) msg3 := Message{} @@ -27,8 +25,6 @@ func TestIsOptionRequested(t *testing.T) { require.False(t, msg1.IsOptionRequested(OptionDNSRecursiveNameServer)) msg2 := Message{} - optro := OptRequestedOption{} - optro.AddRequestedOption(OptionDNSRecursiveNameServer) - msg2.AddOption(&optro) + msg2.AddOption(OptRequestedOption(OptionDNSRecursiveNameServer)) require.True(t, msg2.IsOptionRequested(OptionDNSRecursiveNameServer)) } diff --git a/dhcpv6/modifiers.go b/dhcpv6/modifiers.go index af5e6a8..860071c 100644 --- a/dhcpv6/modifiers.go +++ b/dhcpv6/modifiers.go @@ -1,7 +1,6 @@ package dhcpv6 import ( - "log" "net" "github.com/insomniacslk/dhcp/iana" @@ -27,21 +26,7 @@ func WithServerID(duid Duid) Modifier { // WithNetboot adds bootfile URL and bootfile param options to a DHCPv6 packet. func WithNetboot(d DHCPv6) { - msg, ok := d.(*Message) - if !ok { - log.Printf("WithNetboot: not a Message") - return - } - // add OptionBootfileURL and OptionBootfileParam - opt := msg.GetOneOption(OptionORO) - if opt == nil { - opt = &OptRequestedOption{} - } - // TODO only add options if they are not there already - oro := opt.(*OptRequestedOption) - oro.AddRequestedOption(OptionBootfileURL) - oro.AddRequestedOption(OptionBootfileParam) - msg.UpdateOption(oro) + WithRequestedOptions(OptionBootfileURL, OptionBootfileParam)(d) } // WithFQDN adds a fully qualified domain name option to the packet @@ -129,16 +114,14 @@ func WithRapidCommit(d DHCPv6) { } // WithRequestedOptions adds requested options to the packet -func WithRequestedOptions(optionCodes ...OptionCode) Modifier { +func WithRequestedOptions(codes ...OptionCode) Modifier { return func(d DHCPv6) { - opt := d.GetOneOption(OptionORO) - if opt == nil { - opt = &OptRequestedOption{} - } - oro := opt.(*OptRequestedOption) - for _, optionCode := range optionCodes { - oro.AddRequestedOption(optionCode) + if msg, ok := d.(*Message); ok { + oro := msg.Options.RequestedOptions() + for _, c := range codes { + oro.Add(c) + } + d.UpdateOption(OptRequestedOption(oro...)) } - d.UpdateOption(oro) } } diff --git a/dhcpv6/modifiers_test.go b/dhcpv6/modifiers_test.go index d904ca1..2179aaa 100644 --- a/dhcpv6/modifiers_test.go +++ b/dhcpv6/modifiers_test.go @@ -37,16 +37,12 @@ func TestWithRequestedOptions(t *testing.T) { // Check if ORO is created when no ORO present m, err := NewMessage(WithRequestedOptions(OptionClientID)) require.NoError(t, err) - opt := m.GetOneOption(OptionORO) - require.NotNil(t, opt) - oro := opt.(*OptRequestedOption) - require.ElementsMatch(t, oro.RequestedOptions(), []OptionCode{OptionClientID}) + oro := m.Options.RequestedOptions() + require.ElementsMatch(t, oro, OptionCodes{OptionClientID}) // Check if already set options are preserved WithRequestedOptions(OptionServerID)(m) - opt = m.GetOneOption(OptionORO) - require.NotNil(t, opt) - oro = opt.(*OptRequestedOption) - require.ElementsMatch(t, oro.RequestedOptions(), []OptionCode{OptionClientID, OptionServerID}) + oro = m.Options.RequestedOptions() + require.ElementsMatch(t, oro, OptionCodes{OptionClientID, OptionServerID}) } func TestWithIANA(t *testing.T) { diff --git a/dhcpv6/option_requestedoption.go b/dhcpv6/option_requestedoption.go index 54ff5bf..0d16c74 100644 --- a/dhcpv6/option_requestedoption.go +++ b/dhcpv6/option_requestedoption.go @@ -7,58 +7,69 @@ import ( "github.com/u-root/u-root/pkg/uio" ) -// OptRequestedOption implements the requested options option. -// -// This module defines the OptRequestedOption structure. -// https://www.ietf.org/rfc/rfc3315.txt -type OptRequestedOption struct { - requestedOptions []OptionCode +// OptionCodes are a collection of option codes. +type OptionCodes []OptionCode + +// Add adds an option to the list, ignoring duplicates. +func (o *OptionCodes) Add(c OptionCode) { + if !o.Contains(c) { + *o = append(*o, c) + } } -func (op *OptRequestedOption) Code() OptionCode { - return OptionORO +// Contains returns whether the option codes contain c. +func (o OptionCodes) Contains(c OptionCode) bool { + for _, oo := range o { + if oo == c { + return true + } + } + return false } -func (op *OptRequestedOption) ToBytes() []byte { +// ToBytes implements Option.ToBytes. +func (o OptionCodes) ToBytes() []byte { buf := uio.NewBigEndianBuffer(nil) - for _, ro := range op.requestedOptions { + for _, ro := range o { buf.Write16(uint16(ro)) } return buf.Data() } -func (op *OptRequestedOption) RequestedOptions() []OptionCode { - return op.requestedOptions +func (o OptionCodes) String() string { + names := make([]string, 0, len(o)) + for _, code := range o { + names = append(names, code.String()) + } + return strings.Join(names, ", ") } -func (op *OptRequestedOption) SetRequestedOptions(opts []OptionCode) { - op.requestedOptions = opts +// FromBytes populates o from binary-encoded data. +func (o *OptionCodes) FromBytes(data []byte) error { + buf := uio.NewBigEndianBuffer(data) + for buf.Has(2) { + o.Add(OptionCode(buf.Read16())) + } + return buf.FinError() } -func (op *OptRequestedOption) AddRequestedOption(opt OptionCode) { - for _, requestedOption := range op.requestedOptions { - if opt == requestedOption { - fmt.Printf("Warning: option %s is already set, appending duplicate", opt) - } +// OptRequestedOption implements the requested options option as defined by RFC +// 3315 Section 22.7. +func OptRequestedOption(o ...OptionCode) Option { + return &optRequestedOption{ + OptionCodes: o, } - op.requestedOptions = append(op.requestedOptions, opt) } -func (op *OptRequestedOption) String() string { - names := make([]string, 0, len(op.requestedOptions)) - for _, code := range op.requestedOptions { - names = append(names, code.String()) - } - return fmt.Sprintf("OptRequestedOption{options=[%v]}", strings.Join(names, ", ")) +type optRequestedOption struct { + OptionCodes } -// build an OptRequestedOption structure from a sequence of bytes. -// The input data does not include option code and length bytes. -func ParseOptRequestedOption(data []byte) (*OptRequestedOption, error) { - var opt OptRequestedOption - buf := uio.NewBigEndianBuffer(data) - for buf.Has(2) { - opt.requestedOptions = append(opt.requestedOptions, OptionCode(buf.Read16())) - } - return &opt, buf.FinError() +// Code implements Option.Code. +func (*optRequestedOption) Code() OptionCode { + return OptionORO +} + +func (op *optRequestedOption) String() string { + return fmt.Sprintf("RequestedOptions: %s", op.OptionCodes) } diff --git a/dhcpv6/option_requestedoption_test.go b/dhcpv6/option_requestedoption_test.go index 3e79480..9941a89 100644 --- a/dhcpv6/option_requestedoption_test.go +++ b/dhcpv6/option_requestedoption_test.go @@ -8,30 +8,33 @@ import ( func TestOptRequestedOption(t *testing.T) { expected := []byte{0, 1, 0, 2} - _, err := ParseOptRequestedOption(expected) + var o optRequestedOption + err := o.FromBytes(expected) require.NoError(t, err, "ParseOptRequestedOption() correct options should not error") } func TestOptRequestedOptionParseOptRequestedOptionTooShort(t *testing.T) { buf := []byte{0, 1, 0} - _, err := ParseOptRequestedOption(buf) + var o optRequestedOption + err := o.FromBytes(buf) require.Error(t, err, "A short option should return an error (must be divisible by 2)") } func TestOptRequestedOptionString(t *testing.T) { buf := []byte{0, 1, 0, 2} - opt, err := ParseOptRequestedOption(buf) + var o optRequestedOption + err := o.FromBytes(buf) require.NoError(t, err) require.Contains( t, - opt.String(), + o.String(), "Client Identifier, Server Identifier", "String() should contain the options specified", ) - opt.AddRequestedOption(12345) + o.OptionCodes = append(o.OptionCodes, 12345) require.Contains( t, - opt.String(), + o.String(), "unknown", "String() should contain 'Unknown' for an illegal option", ) diff --git a/dhcpv6/options.go b/dhcpv6/options.go index 617761f..0a9ab0e 100644 --- a/dhcpv6/options.go +++ b/dhcpv6/options.go @@ -48,7 +48,9 @@ func ParseOption(code OptionCode, optData []byte) (Option, error) { case OptionIAAddr: opt, err = ParseOptIAAddress(optData) case OptionORO: - opt, err = ParseOptRequestedOption(optData) + var o optRequestedOption + err = o.FromBytes(optData) + opt = &o case OptionElapsedTime: opt, err = ParseOptElapsedTime(optData) case OptionRelayMsg: |