From c85893e1cfbf875afdc99643c0a55f18187343ce Mon Sep 17 00:00:00 2001 From: Chris Koch Date: Sat, 25 Feb 2023 23:16:41 -0800 Subject: StatusCode: test for all Options types Signed-off-by: Chris Koch --- dhcpv6/option_statuscode_test.go | 68 ++++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 23 deletions(-) (limited to 'dhcpv6') diff --git a/dhcpv6/option_statuscode_test.go b/dhcpv6/option_statuscode_test.go index 3c99e4f..a2655b2 100644 --- a/dhcpv6/option_statuscode_test.go +++ b/dhcpv6/option_statuscode_test.go @@ -12,12 +12,47 @@ import ( "github.com/u-root/uio/uio" ) +type optionsWithStatusCode interface { + Status() *OptStatusCode + ToBytes() []byte +} + +type optionsPtr[O any] interface { + *O + FromBytes([]byte) error + Add(o Option) +} + +type testCase struct { + buf []byte + err error + want *OptStatusCode +} + +func testParseStatus[MO optionsWithStatusCode, OA optionsPtr[MO]](t *testing.T, tt testCase) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + var mo MO + if err := OA(&mo).FromBytes(tt.buf); !errors.Is(err, tt.err) { + t.Errorf("FromBytes = %v, want %v", err, tt.err) + } + if got := mo.Status(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Status = %v, want %v", got, tt.want) + } + + if tt.want != nil { + var m MO + OA(&m).Add(tt.want) + got := m.ToBytes() + if diff := cmp.Diff(tt.buf, got); diff != "" { + t.Errorf("ToBytes mismatch (-want, +got): %s", diff) + } + } + } +} + func TestStatusCodeParseAndGetter(t *testing.T) { - for i, tt := range []struct { - buf []byte - err error - want *OptStatusCode - }{ + for i, tt := range []testCase{ { buf: []byte{ 0, 13, // StatusCode option @@ -45,24 +80,11 @@ func TestStatusCodeParseAndGetter(t *testing.T) { err: uio.ErrUnreadBytes, }, } { - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - var mo MessageOptions - if err := mo.FromBytes(tt.buf); !errors.Is(err, tt.err) { - t.Errorf("FromBytes = %v, want %v", err, tt.err) - } - if got := mo.Status(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("Status = %v, want %v", got, tt.want) - } - - if tt.want != nil { - var m MessageOptions - m.Add(tt.want) - got := m.ToBytes() - if diff := cmp.Diff(tt.buf, got); diff != "" { - t.Errorf("ToBytes mismatch (-want, +got): %s", diff) - } - } - }) + t.Run(fmt.Sprintf("MO-%d", i), testParseStatus[MessageOptions, *MessageOptions](t, tt)) + t.Run(fmt.Sprintf("IO-%d", i), testParseStatus[IdentityOptions, *IdentityOptions](t, tt)) + t.Run(fmt.Sprintf("AO-%d", i), testParseStatus[AddressOptions, *AddressOptions](t, tt)) + t.Run(fmt.Sprintf("PDO-%d", i), testParseStatus[PDOptions, *PDOptions](t, tt)) + t.Run(fmt.Sprintf("PO-%d", i), testParseStatus[PrefixOptions, *PrefixOptions](t, tt)) } } -- cgit v1.2.3