summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorChris Koch <chrisko@google.com>2023-02-26 17:37:44 -0800
committerChris K <c@chrisko.ch>2023-02-27 10:35:19 -0800
commitab762aee6bf0c8d70706b3297514858e5efcb103 (patch)
treee85b7df4b27849f3c5c7ea0f4453f4e0ea9f2741
parent7405adac5bdf7b8a148d9ec912ec78dd5ac3d908 (diff)
Vendor class: new Getters & tests for Getters, FromBytes, ToBytes
Signed-off-by: Chris Koch <chrisko@google.com>
-rw-r--r--dhcpv6/dhcpv6message.go27
-rw-r--r--dhcpv6/option_vendorclass.go5
-rw-r--r--dhcpv6/option_vendorclass_test.go133
3 files changed, 103 insertions, 62 deletions
diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go
index 72332c5..9bb83bb 100644
--- a/dhcpv6/dhcpv6message.go
+++ b/dhcpv6/dhcpv6message.go
@@ -211,6 +211,33 @@ func (mo MessageOptions) UserClasses() [][]byte {
return nil
}
+// VendorClasses returns the all vendor class options.
+func (mo MessageOptions) VendorClasses() []*OptVendorClass {
+ opt := mo.Options.Get(OptionVendorClass)
+ if opt == nil {
+ return nil
+ }
+ var vo []*OptVendorClass
+ for _, o := range opt {
+ if t, ok := o.(*OptVendorClass); ok {
+ vo = append(vo, t)
+ }
+ }
+ return vo
+}
+
+// VendorClass returns the vendor class options matching the given enterprise
+// number.
+func (mo MessageOptions) VendorClass(enterpriseNumber uint32) [][]byte {
+ vo := mo.VendorClasses()
+ for _, v := range vo {
+ if v.EnterpriseNumber == enterpriseNumber {
+ return v.Data
+ }
+ }
+ return nil
+}
+
// VendorOpts returns the all vendor-specific options.
//
// RFC 8415 Section 21.17:
diff --git a/dhcpv6/option_vendorclass.go b/dhcpv6/option_vendorclass.go
index 954dbd0..f85795e 100644
--- a/dhcpv6/option_vendorclass.go
+++ b/dhcpv6/option_vendorclass.go
@@ -1,7 +1,6 @@
package dhcpv6
import (
- "errors"
"fmt"
"strings"
@@ -49,8 +48,8 @@ func (op *OptVendorClass) FromBytes(data []byte) error {
len := buf.Read16()
op.Data = append(op.Data, buf.CopyN(int(len)))
}
- if len(op.Data) < 1 {
- return errors.New("ParseOptVendorClass: at least one vendor class data is required")
+ if len(op.Data) == 0 {
+ return fmt.Errorf("%w: vendor class data should not be empty", uio.ErrBufferTooShort)
}
return buf.FinError()
}
diff --git a/dhcpv6/option_vendorclass_test.go b/dhcpv6/option_vendorclass_test.go
index c691176..5fbc0e6 100644
--- a/dhcpv6/option_vendorclass_test.go
+++ b/dhcpv6/option_vendorclass_test.go
@@ -1,73 +1,88 @@
package dhcpv6
import (
+ "errors"
+ "fmt"
+ "reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
+ "github.com/u-root/uio/uio"
)
-func TestParseOptVendorClass(t *testing.T) {
- data := []byte{
- 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber
- 0, 10, 'H', 'T', 'T', 'P', 'C', 'l', 'i', 'e', 'n', 't',
- 0, 4, 't', 'e', 's', 't',
- }
- var opt OptVendorClass
- err := opt.FromBytes(data)
- require.NoError(t, err)
- require.Equal(t, OptionVendorClass, opt.Code())
- require.Equal(t, 2, len(opt.Data))
- require.Equal(t, uint32(0xaabbccdd), opt.EnterpriseNumber)
- require.Equal(t, []byte("HTTPClient"), opt.Data[0])
- require.Equal(t, []byte("test"), opt.Data[1])
-}
-
-func TestOptVendorClassToBytes(t *testing.T) {
- opt := OptVendorClass{
- EnterpriseNumber: uint32(0xaabbccdd),
- Data: [][]byte{
- []byte("HTTPClient"),
- []byte("test"),
+func TestVendorClassParseAndGetter(t *testing.T) {
+ for i, tt := range []struct {
+ buf []byte
+ err error
+ want []*OptVendorClass
+ }{
+ {
+ buf: []byte{
+ 0, 16, // Vendor Class
+ 0, 14, // length
+ 0, 0, 0, 16,
+ 0, 4,
+ 'S', 'L', 'A', 'M',
+ 0, 2,
+ 'h', 'h',
+ },
+ want: []*OptVendorClass{
+ &OptVendorClass{
+ EnterpriseNumber: 16,
+ Data: [][]byte{[]byte("SLAM"), []byte("hh")},
+ },
+ },
},
- }
- data := opt.ToBytes()
- expected := []byte{
- 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber
- 0, 10, 'H', 'T', 'T', 'P', 'C', 'l', 'i', 'e', 'n', 't',
- 0, 4, 't', 'e', 's', 't',
- }
- require.Equal(t, expected, data)
-}
-
-func TestOptVendorClassParseOptVendorClassMalformed(t *testing.T) {
- buf := []byte{
- 0xaa, 0xbb, // truncated EnterpriseNumber
- }
- var opt OptVendorClass
- err := opt.FromBytes(buf)
- require.Error(t, err, "ParseOptVendorClass() should error if given truncated EnterpriseNumber")
-
- buf = []byte{
- 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber
- }
- err = opt.FromBytes(buf)
- require.Error(t, err, "ParseOptVendorClass() should error if given no vendor classes")
-
- buf = []byte{
- 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber
- 0, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't',
- 0, 4, 't', 'e',
- }
- err = opt.FromBytes(buf)
- require.Error(t, err, "ParseOptVendorClass() should error if given truncated vendor classes")
+ {
+ buf: []byte{
+ 0, 16,
+ 0, 0,
+ },
+ err: uio.ErrBufferTooShort,
+ },
+ {
+ buf: []byte{
+ 0, 16,
+ 0, 4,
+ 0, 0, 0, 6,
+ },
+ err: uio.ErrBufferTooShort,
+ },
+ {
+ buf: []byte{0, 16, 0},
+ err: uio.ErrUnreadBytes,
+ },
+ } {
+ t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+ var mo MessageOptions
+ if err := mo.FromBytes(tt.buf); !errors.Is(err, tt.err) {
+ t.Errorf("FromBytes = %v, want %v", err, tt.err)
+ }
+ if got := mo.VendorClasses(); !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("VendorClass = %v, want %v", got, tt.want)
+ }
+ for _, v := range tt.want {
+ if got := mo.VendorClass(v.EnterpriseNumber); !reflect.DeepEqual(got, v.Data) {
+ t.Errorf("VendorClass(%d) = %v, want %v", v.EnterpriseNumber, got, v.Data)
+ }
+ }
+ if got := mo.VendorClass(100); got != nil {
+ t.Errorf("VendorClass(100) = %v, want nil", got)
+ }
- buf = []byte{
- 0xaa, 0xbb, 0xcc, 0xdd, // EnterpriseNumber
- 0, 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't',
- 0,
+ if tt.want != nil {
+ var m MessageOptions
+ for _, o := range tt.want {
+ m.Add(o)
+ }
+ got := m.ToBytes()
+ if diff := cmp.Diff(tt.buf, got); diff != "" {
+ t.Errorf("ToBytes mismatch (-want, +got): %s", diff)
+ }
+ }
+ })
}
- err = opt.FromBytes(buf)
- require.Error(t, err, "ParseOptVendorClass() should error if given a truncated length")
}
func TestOptVendorClassString(t *testing.T) {