summaryrefslogtreecommitdiffhomepage
path: root/dhcpv6/option_statuscode_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv6/option_statuscode_test.go')
-rw-r--r--dhcpv6/option_statuscode_test.go68
1 files changed, 45 insertions, 23 deletions
diff --git a/dhcpv6/option_statuscode_test.go b/dhcpv6/option_statuscode_test.go
index 3c99e4f..a2655b2 100644
--- a/dhcpv6/option_statuscode_test.go
+++ b/dhcpv6/option_statuscode_test.go
@@ -12,12 +12,47 @@ import (
"github.com/u-root/uio/uio"
)
+type optionsWithStatusCode interface {
+ Status() *OptStatusCode
+ ToBytes() []byte
+}
+
+type optionsPtr[O any] interface {
+ *O
+ FromBytes([]byte) error
+ Add(o Option)
+}
+
+type testCase struct {
+ buf []byte
+ err error
+ want *OptStatusCode
+}
+
+func testParseStatus[MO optionsWithStatusCode, OA optionsPtr[MO]](t *testing.T, tt testCase) func(t *testing.T) {
+ return func(t *testing.T) {
+ t.Helper()
+ var mo MO
+ if err := OA(&mo).FromBytes(tt.buf); !errors.Is(err, tt.err) {
+ t.Errorf("FromBytes = %v, want %v", err, tt.err)
+ }
+ if got := mo.Status(); !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Status = %v, want %v", got, tt.want)
+ }
+
+ if tt.want != nil {
+ var m MO
+ OA(&m).Add(tt.want)
+ got := m.ToBytes()
+ if diff := cmp.Diff(tt.buf, got); diff != "" {
+ t.Errorf("ToBytes mismatch (-want, +got): %s", diff)
+ }
+ }
+ }
+}
+
func TestStatusCodeParseAndGetter(t *testing.T) {
- for i, tt := range []struct {
- buf []byte
- err error
- want *OptStatusCode
- }{
+ for i, tt := range []testCase{
{
buf: []byte{
0, 13, // StatusCode option
@@ -45,24 +80,11 @@ func TestStatusCodeParseAndGetter(t *testing.T) {
err: uio.ErrUnreadBytes,
},
} {
- t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
- var mo MessageOptions
- if err := mo.FromBytes(tt.buf); !errors.Is(err, tt.err) {
- t.Errorf("FromBytes = %v, want %v", err, tt.err)
- }
- if got := mo.Status(); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("Status = %v, want %v", got, tt.want)
- }
-
- if tt.want != nil {
- var m MessageOptions
- m.Add(tt.want)
- got := m.ToBytes()
- if diff := cmp.Diff(tt.buf, got); diff != "" {
- t.Errorf("ToBytes mismatch (-want, +got): %s", diff)
- }
- }
- })
+ t.Run(fmt.Sprintf("MO-%d", i), testParseStatus[MessageOptions, *MessageOptions](t, tt))
+ t.Run(fmt.Sprintf("IO-%d", i), testParseStatus[IdentityOptions, *IdentityOptions](t, tt))
+ t.Run(fmt.Sprintf("AO-%d", i), testParseStatus[AddressOptions, *AddressOptions](t, tt))
+ t.Run(fmt.Sprintf("PDO-%d", i), testParseStatus[PDOptions, *PDOptions](t, tt))
+ t.Run(fmt.Sprintf("PO-%d", i), testParseStatus[PrefixOptions, *PrefixOptions](t, tt))
}
}