From ab762aee6bf0c8d70706b3297514858e5efcb103 Mon Sep 17 00:00:00 2001 From: Chris Koch Date: Sun, 26 Feb 2023 17:37:44 -0800 Subject: Vendor class: new Getters & tests for Getters, FromBytes, ToBytes Signed-off-by: Chris Koch --- dhcpv6/dhcpv6message.go | 27 ++++++++ dhcpv6/option_vendorclass.go | 5 +- dhcpv6/option_vendorclass_test.go | 133 +++++++++++++++++++++----------------- 3 files changed, 103 insertions(+), 62 deletions(-) diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go index 72332c5..9bb83bb 100644 --- a/dhcpv6/dhcpv6message.go +++ b/dhcpv6/dhcpv6message.go @@ -211,6 +211,33 @@ func (mo MessageOptions) UserClasses() [][]byte { return nil } +// VendorClasses returns the all vendor class options. +func (mo MessageOptions) VendorClasses() []*OptVendorClass { + opt := mo.Options.Get(OptionVendorClass) + if opt == nil { + return nil + } + var vo []*OptVendorClass + for _, o := range opt { + if t, ok := o.(*OptVendorClass); ok { + vo = append(vo, t) + } + } + return vo +} + +// VendorClass returns the vendor class options matching the given enterprise +// number. +func (mo MessageOptions) VendorClass(enterpriseNumber uint32) [][]byte { + vo := mo.VendorClasses() + for _, v := range vo { + if v.EnterpriseNumber == enterpriseNumber { + return v.Data + } + } + return nil +} + // VendorOpts returns the all vendor-specific options. // // RFC 8415 Section 21.17: diff --git a/dhcpv6/option_vendorclass.go b/dhcpv6/option_vendorclass.go index 954dbd0..f85795e 100644 --- a/dhcpv6/option_vendorclass.go +++ b/dhcpv6/option_vendorclass.go @@ -1,7 +1,6 @@ package dhcpv6 import ( - "errors" "fmt" "strings" @@ -49,8 +48,8 @@ func (op *OptVendorClass) FromBytes(data []byte) error { len := buf.Read16() op.Data = append(op.Data, buf.CopyN(int(len))) } - if len(op.Data) < 1 { - return errors.New("ParseOptVendorClass: at least one vendor class data is required") + if len(op.Data) == 0 { + return fmt.Errorf("%w: vendor class data should not be empty", uio.ErrBufferTooShort) } return buf.FinError() } diff --git a/dhcpv6/option_vendorclass_test.go b/dhcpv6/option_vendorclass_test.go index c691176..5fbc0e6 100644 --- a/dhcpv6/option_vendorclass_test.go +++ b/dhcpv6/option_vendorclass_test.go @@ -1,73 +1,88 @@ package dhcpv6 import ( + "errors" + "fmt" + "reflect" "testing" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" + "github.com/u-root/uio/uio" ) -func TestParseOptVendorClass(t *testing.T) { - data := []byte{ - 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber - 0, 10, 'H', 'T', 'T', 'P', 'C', 'l', 'i', 'e', 'n', 't', - 0, 4, 't', 'e', 's', 't', - } - var opt OptVendorClass - err := opt.FromBytes(data) - require.NoError(t, err) - require.Equal(t, OptionVendorClass, opt.Code()) - require.Equal(t, 2, len(opt.Data)) - require.Equal(t, uint32(0xaabbccdd), opt.EnterpriseNumber) - require.Equal(t, []byte("HTTPClient"), opt.Data[0]) - require.Equal(t, []byte("test"), opt.Data[1]) -} - -func TestOptVendorClassToBytes(t *testing.T) { - opt := OptVendorClass{ - EnterpriseNumber: uint32(0xaabbccdd), - Data: [][]byte{ - []byte("HTTPClient"), - []byte("test"), +func TestVendorClassParseAndGetter(t *testing.T) { + for i, tt := range []struct { + buf []byte + err error + want []*OptVendorClass + }{ + { + buf: []byte{ + 0, 16, // Vendor Class + 0, 14, // length + 0, 0, 0, 16, + 0, 4, + 'S', 'L', 'A', 'M', + 0, 2, + 'h', 'h', + }, + want: []*OptVendorClass{ + &OptVendorClass{ + EnterpriseNumber: 16, + Data: [][]byte{[]byte("SLAM"), []byte("hh")}, + }, + }, }, - } - data := opt.ToBytes() - expected := []byte{ - 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber - 0, 10, 'H', 'T', 'T', 'P', 'C', 'l', 'i', 'e', 'n', 't', - 0, 4, 't', 'e', 's', 't', - } - require.Equal(t, expected, data) -} - -func TestOptVendorClassParseOptVendorClassMalformed(t *testing.T) { - buf := []byte{ - 0xaa, 0xbb, // truncated EnterpriseNumber - } - var opt OptVendorClass - err := opt.FromBytes(buf) - require.Error(t, err, "ParseOptVendorClass() should error if given truncated EnterpriseNumber") - - buf = []byte{ - 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber - } - err = opt.FromBytes(buf) - require.Error(t, err, "ParseOptVendorClass() should error if given no vendor classes") - - buf = []byte{ - 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber - 0, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', - 0, 4, 't', 'e', - } - err = opt.FromBytes(buf) - require.Error(t, err, "ParseOptVendorClass() should error if given truncated vendor classes") + { + buf: []byte{ + 0, 16, + 0, 0, + }, + err: uio.ErrBufferTooShort, + }, + { + buf: []byte{ + 0, 16, + 0, 4, + 0, 0, 0, 6, + }, + err: uio.ErrBufferTooShort, + }, + { + buf: []byte{0, 16, 0}, + err: uio.ErrUnreadBytes, + }, + } { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + var mo MessageOptions + if err := mo.FromBytes(tt.buf); !errors.Is(err, tt.err) { + t.Errorf("FromBytes = %v, want %v", err, tt.err) + } + if got := mo.VendorClasses(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("VendorClass = %v, want %v", got, tt.want) + } + for _, v := range tt.want { + if got := mo.VendorClass(v.EnterpriseNumber); !reflect.DeepEqual(got, v.Data) { + t.Errorf("VendorClass(%d) = %v, want %v", v.EnterpriseNumber, got, v.Data) + } + } + if got := mo.VendorClass(100); got != nil { + t.Errorf("VendorClass(100) = %v, want nil", got) + } - buf = []byte{ - 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber - 0, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', - 0, + if tt.want != nil { + var m MessageOptions + for _, o := range tt.want { + m.Add(o) + } + got := m.ToBytes() + if diff := cmp.Diff(tt.buf, got); diff != "" { + t.Errorf("ToBytes mismatch (-want, +got): %s", diff) + } + } + }) } - err = opt.FromBytes(buf) - require.Error(t, err, "ParseOptVendorClass() should error if given a truncated length") } func TestOptVendorClassString(t *testing.T) { -- cgit v1.2.3