diff options
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/bsdp/bsdp.go | 14 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp_test.go | 103 | ||||
-rw-r--r-- | dhcpv4/dhcpv4_test.go | 154 | ||||
-rw-r--r-- | dhcpv4/options_test.go | 105 |
4 files changed, 152 insertions, 224 deletions
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 6c8d00f..d4d4e31 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -217,8 +217,8 @@ func NewInformListForInterface(iface string, replyPort uint16) (*dhcpv4.DHCPv4, d.AddOption(dhcpv4.Option{ Code: dhcpv4.OptionParameterRequestList, Data: []byte{ - dhcpv4.OptionVendorSpecificInformation, - dhcpv4.OptionClassIdentifier, + byte(dhcpv4.OptionVendorSpecificInformation), + byte(dhcpv4.OptionClassIdentifier), }, }) @@ -316,11 +316,11 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI d.AddOption(dhcpv4.Option{ Code: dhcpv4.OptionParameterRequestList, Data: []byte{ - dhcpv4.OptionSubnetMask, - dhcpv4.OptionRouter, - dhcpv4.OptionBootfileName, - dhcpv4.OptionVendorSpecificInformation, - dhcpv4.OptionClassIdentifier, + byte(dhcpv4.OptionSubnetMask), + byte(dhcpv4.OptionRouter), + byte(dhcpv4.OptionBootfileName), + byte(dhcpv4.OptionVendorSpecificInformation), + byte(dhcpv4.OptionClassIdentifier), }, }) d.AddOption(dhcpv4.Option{ diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go index b66efbc..9d5ca1f 100644 --- a/dhcpv4/bsdp/bsdp_test.go +++ b/dhcpv4/bsdp/bsdp_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) /* @@ -18,12 +18,12 @@ func TestBootImageIDToBytes(t *testing.T) { } actual := b.ToBytes() expected := []byte{0x81, 0, 0x10, 0} - assert.Equal(t, actual, expected, "serialized BootImageID should be equal") + require.Equal(t, expected, actual) b.IsInstall = false actual = b.ToBytes() expected = []byte{0x01, 0, 0x10, 0} - assert.Equal(t, actual, expected, "serialized BootImageID should be equal") + require.Equal(t, expected, actual) } func TestBootImageIDFromBytes(t *testing.T) { @@ -33,8 +33,8 @@ func TestBootImageIDFromBytes(t *testing.T) { Index: 0x1000, } newBootImage, err := BootImageIDFromBytes(b.ToBytes()) - assert.Nil(t, err, "error from BootImageIDFromBytes") - assert.Equal(t, b, *newBootImage, "deserialized BootImage should be equal") + require.NoError(t, err) + require.Equal(t, b, *newBootImage) b = BootImageID{ IsInstall: true, @@ -42,15 +42,15 @@ func TestBootImageIDFromBytes(t *testing.T) { Index: 0x1011, } newBootImage, err = BootImageIDFromBytes(b.ToBytes()) - assert.Nil(t, err, "error from BootImageIDFromBytes") - assert.Equal(t, b, *newBootImage, "deserialized BootImage should be equal") + require.NoError(t, err) + require.Equal(t, b, *newBootImage) } func TestBootImageIDFromBytesFail(t *testing.T) { serialized := []byte{0x81, 0, 0x10} // intentionally left short deserialized, err := BootImageIDFromBytes(serialized) - assert.Nil(t, deserialized, "BootImageIDFromBytes should return nil on failed deserialization") - assert.NotNil(t, err, "BootImageIDFromBytes should return err on failed deserialization") + require.Nil(t, deserialized) + require.Error(t, err) } /* @@ -71,7 +71,7 @@ func TestBootImageToBytes(t *testing.T) { 98, 115, 100, 112, 45, 49, // byte-encoding of Name } actual := b.ToBytes() - assert.Equal(t, actual, expected, "serialized BootImage should be equal") + require.Equal(t, expected, actual) b = BootImage{ ID: BootImageID{ @@ -87,7 +87,7 @@ func TestBootImageToBytes(t *testing.T) { 98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name } actual = b.ToBytes() - assert.Equal(t, actual, expected, "serialized BootImage should be equal") + require.Equal(t, expected, actual) } func TestBootImageFromBytes(t *testing.T) { @@ -97,7 +97,7 @@ func TestBootImageFromBytes(t *testing.T) { 98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name } b, err := BootImageFromBytes(input) - assert.Nil(t, err, "error while marshalling BootImage") + require.NoError(t, err) expectedBootImage := BootImage{ ID: BootImageID{ IsInstall: false, @@ -106,15 +106,15 @@ func TestBootImageFromBytes(t *testing.T) { }, Name: "bsdp-21", } - assert.Equal(t, *b, expectedBootImage, "invalid marshalling of BootImage") + require.Equal(t, expectedBootImage, *b) } func TestBootImageFromBytesOnlyBootImageID(t *testing.T) { // Only a BootImageID, nothing else. input := []byte{0x1, 0, 0x10, 0x10} b, err := BootImageFromBytes(input) - assert.Nil(t, b, "short bytestream should return nil BootImageID") - assert.NotNil(t, err, "short bytestream should return error") + require.Nil(t, b) + require.Error(t, err) } func TestBootImageFromBytesShortBootImage(t *testing.T) { @@ -124,8 +124,8 @@ func TestBootImageFromBytesShortBootImage(t *testing.T) { 98, 115, 100, 112, 45, 50, // Name bytes (intentionally off-by-one) } b, err := BootImageFromBytes(input) - assert.Nil(t, b, "short bytestream should return nil BootImageID") - assert.NotNil(t, err, "short bytestream should return error") + require.Nil(t, b) + require.Error(t, err) } func TestParseBootImageSingleBootImage(t *testing.T) { @@ -135,16 +135,16 @@ func TestParseBootImageSingleBootImage(t *testing.T) { 98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name } bs, err := ParseBootImagesFromOption(input) - assert.Nil(t, err, "parsing single boot image should not return error") - assert.Equal(t, len(bs), 1, "parsing single boot image should return 1") + require.NoError(t, err) + require.Equal(t, len(bs), 1, "parsing single boot image should return 1") b := bs[0] - expectedBootImage := BootImageID{ + expectedBootImageID := BootImageID{ IsInstall: false, ImageType: BootImageTypeMacOSX, Index: 0x1010, } - assert.Equal(t, b.ID, expectedBootImage, "parsed BootImageIDs should be equal") - assert.Equal(t, b.Name, "bsdp-21", "BootImage name should be equal") + require.Equal(t, expectedBootImageID, b.ID) + require.Equal(t, b.Name, "bsdp-21") } func TestParseBootImageMultipleBootImage(t *testing.T) { @@ -160,8 +160,8 @@ func TestParseBootImageMultipleBootImage(t *testing.T) { 98, 115, 100, 112, 45, 50, 50, 50, // byte-encoding of Name } bs, err := ParseBootImagesFromOption(input) - assert.Nil(t, err, "parsing multiple BootImages should not return error") - assert.Equal(t, len(bs), 2, "parsing 2 BootImages should return 2") + require.NoError(t, err) + require.Equal(t, len(bs), 2, "parsing 2 BootImages should return 2") b1 := bs[0] b2 := bs[1] expectedID1 := BootImageID{ @@ -174,18 +174,18 @@ func TestParseBootImageMultipleBootImage(t *testing.T) { ImageType: BootImageTypeMacOSXServer, Index: 0x1122, } - assert.Equal(t, b1.ID, expectedID1, "first BootImageID should be equal") - assert.Equal(t, b2.ID, expectedID2, "second BootImageID should be equal") - assert.Equal(t, b1.Name, "bsdp-21", "first BootImage name should be equal") - assert.Equal(t, b2.Name, "bsdp-222", "second BootImage name should be equal") + require.Equal(t, expectedID1, b1.ID, "first BootImageID should be equal") + require.Equal(t, expectedID2, b2.ID, "second BootImageID should be equal") + require.Equal(t, "bsdp-21", b1.Name, "first BootImage name should be equal") + require.Equal(t, "bsdp-222", b2.Name, "second BootImage name should be equal") } func TestParseBootImageFail(t *testing.T) { _, err := ParseBootImagesFromOption([]byte{}) - assert.NotNil(t, err, "parseBootImages with empty arg") + require.Error(t, err, "parseBootImages with empty arg") _, err = ParseBootImagesFromOption([]byte{1, 2, 3}) - assert.NotNil(t, err, "parseBootImages with short arg") + require.Error(t, err, "parseBootImages with short arg") _, err = ParseBootImagesFromOption([]byte{ // boot image 1 @@ -198,14 +198,14 @@ func TestParseBootImageFail(t *testing.T) { 8, // len(Name) 98, 115, 100, 112, 45, 50, 50, 50, // byte-encoding of Name }) - assert.NotNil(t, err, "parseBootImages with short arg") + require.Error(t, err, "parseBootImages with short arg") } /* * ParseVendorOptionsFromOptions */ func TestParseVendorOptions(t *testing.T) { - vendorOpts := []dhcpv4.Option{ + expectedOpts := []dhcpv4.Option{ dhcpv4.Option{ Code: OptionMessageType, Data: []byte{byte(MessageTypeList)}, @@ -226,15 +226,15 @@ func TestParseVendorOptions(t *testing.T) { }, dhcpv4.Option{ Code: dhcpv4.OptionVendorSpecificInformation, - Data: dhcpv4.OptionsToBytesWithoutMagicCookie(vendorOpts), + Data: dhcpv4.OptionsToBytesWithoutMagicCookie(expectedOpts), }, } opts := ParseVendorOptionsFromOptions(recvOpts) - assert.Equal(t, opts, vendorOpts, "Parsed vendorOpts should be the same") + require.Equal(t, expectedOpts, opts, "Parsed vendorOpts should be the same") } func TestParseVendorOptionsFromOptionsNotPresent(t *testing.T) { - recvOpts := []dhcpv4.Option{ + expectedOpts := []dhcpv4.Option{ dhcpv4.Option{ Code: dhcpv4.OptionDHCPMessageType, Data: []byte{byte(dhcpv4.MessageTypeAck)}, @@ -244,13 +244,13 @@ func TestParseVendorOptionsFromOptionsNotPresent(t *testing.T) { Data: []byte{0xff, 0xff, 0xff, 0xff}, }, } - opts := ParseVendorOptionsFromOptions(recvOpts) - assert.Empty(t, opts, "vendor opts should be empty if not present in input") + opts := ParseVendorOptionsFromOptions(expectedOpts) + require.Empty(t, opts, "empty vendor opts if not present in DHCP opts") } func TestParseVendorOptionsFromOptionsEmpty(t *testing.T) { - options := ParseVendorOptionsFromOptions([]dhcpv4.Option{}) - assert.Empty(t, options, "vendor opts should be empty if given an empty input") + opts := ParseVendorOptionsFromOptions([]dhcpv4.Option{}) + require.Empty(t, opts, "vendor opts should be empty if given an empty input") } func TestParseVendorOptionsFromOptionsFail(t *testing.T) { @@ -264,14 +264,14 @@ func TestParseVendorOptionsFromOptionsFail(t *testing.T) { }, } vendorOpts := ParseVendorOptionsFromOptions(opts) - assert.Empty(t, vendorOpts, "vendor opts should be empty on parse error") + require.Empty(t, vendorOpts, "vendor opts should be empty on parse error") } /* * ParseBootImageListFromAck */ func TestParseBootImageListFromAck(t *testing.T) { - bootImages := []BootImage{ + expectedBootImages := []BootImage{ BootImage{ ID: BootImageID{ IsInstall: true, @@ -290,7 +290,7 @@ func TestParseBootImageListFromAck(t *testing.T) { }, } var bootImageBytes []byte - for _, image := range bootImages { + for _, image := range expectedBootImages { bootImageBytes = append(bootImageBytes, image.ToBytes()...) } ack, _ := dhcpv4.New() @@ -305,9 +305,8 @@ func TestParseBootImageListFromAck(t *testing.T) { }) images, err := ParseBootImageListFromAck(*ack) - assert.Nil(t, err, "error from ParseBootImageListFromAck") - assert.NotNil(t, images, "parsed boot images from ack") - assert.Equal(t, images, bootImages, "should get same BootImages") + require.NoError(t, err) + require.Equal(t, expectedBootImages, images, "should get same BootImages") } func TestParseBootImageListFromAckNoVendorOption(t *testing.T) { @@ -317,8 +316,8 @@ func TestParseBootImageListFromAckNoVendorOption(t *testing.T) { Data: []byte{byte(dhcpv4.MessageTypeAck)}, }) images, err := ParseBootImageListFromAck(*ack) - assert.Nil(t, err, "no vendor extensions should not return error") - assert.Empty(t, images, "should not get images from ACK without Vendor extensions") + require.NoError(t, err, "no vendor extensions should not return error") + require.Empty(t, images, "should not get images from ACK without Vendor extensions") } func TestParseBootImageListFromAckFail(t *testing.T) { @@ -348,15 +347,15 @@ func TestParseBootImageListFromAckFail(t *testing.T) { }) images, err := ParseBootImageListFromAck(*ack) - assert.Nil(t, images, "should get nil on parse error") - assert.NotNil(t, err, "should get error on parse error") + require.Nil(t, images, "should get nil on parse error") + require.Error(t, err, "should get error on parse error") } /* * Private funcs */ func TestNeedsReplyPort(t *testing.T) { - assert.True(t, needsReplyPort(123), "") - assert.False(t, needsReplyPort(0), "") - assert.False(t, needsReplyPort(dhcpv4.ClientPort), "") + require.True(t, needsReplyPort(123)) + require.False(t, needsReplyPort(0)) + require.False(t, needsReplyPort(dhcpv4.ClientPort)) } diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index 0d569c8..e9824ef 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -1,31 +1,16 @@ package dhcpv4 import ( - "bytes" "net" "testing" "github.com/insomniacslk/dhcp/iana" + "github.com/stretchr/testify/require" ) -// NOTE: if one of the following Assert* fails where expected and got values are -// the same, you probably have to cast one of them to match the other one's -// type, e.g. comparing int and byte, even the same value, will fail. -func AssertEqual(t *testing.T, a, b interface{}, what string) { - if a != b { - t.Fatalf("Invalid %s. %v != %v", what, a, b) - } -} - -func AssertEqualBytes(t *testing.T, a, b []byte, what string) { - if !bytes.Equal(a, b) { - t.Fatalf("Invalid %s. %v != %v", what, a, b) - } -} - -func AssertEqualIPAddr(t *testing.T, a, b net.IP, what string) { +func RequireEqualIPAddr(t *testing.T, a, b net.IP, msg ...interface{}) { if !net.IP.Equal(a, b) { - t.Fatalf("Invalid %s. %v != %v", what, a, b) + t.Fatalf("Invalid %s. %v != %v", msg, a, b) } } @@ -60,26 +45,23 @@ func TestFromBytes(t *testing.T) { data = append(data, []byte{99, 130, 83, 99}...) d, err := FromBytes(data) - if err != nil { - t.Fatal(err) - } - AssertEqual(t, d.Opcode(), OpcodeBootRequest, "opcode") - AssertEqual(t, d.HwType(), iana.HwTypeEthernet, "hardware type") - AssertEqual(t, d.HwAddrLen(), byte(6), "hardware address length") - AssertEqual(t, d.HopCount(), byte(3), "hop count") - AssertEqual(t, d.TransactionID(), uint32(0xaabbccdd), "transaction ID") - AssertEqual(t, d.NumSeconds(), uint16(3), "number of seconds") - AssertEqual(t, d.Flags(), uint16(1), "flags") - AssertEqualIPAddr(t, d.ClientIPAddr(), net.IPv4zero, "client IP address") - AssertEqualIPAddr(t, d.YourIPAddr(), net.IPv4zero, "your IP address") - AssertEqualIPAddr(t, d.ServerIPAddr(), net.IPv4zero, "server IP address") - AssertEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4zero, "gateway IP address") - hwaddr := d.ClientHwAddr() - AssertEqualBytes(t, hwaddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, "flags") + require.NoError(t, err) + require.Equal(t, d.Opcode(), OpcodeBootRequest) + require.Equal(t, d.HwType(), iana.HwTypeEthernet) + require.Equal(t, d.HwAddrLen(), byte(6)) + require.Equal(t, d.HopCount(), byte(3)) + require.Equal(t, d.TransactionID(), uint32(0xaabbccdd)) + require.Equal(t, d.NumSeconds(), uint16(3)) + require.Equal(t, d.Flags(), uint16(1)) + RequireEqualIPAddr(t, d.ClientIPAddr(), net.IPv4zero) + RequireEqualIPAddr(t, d.YourIPAddr(), net.IPv4zero) + RequireEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4zero) + clientHwAddr := d.ClientHwAddr() + require.Equal(t, clientHwAddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) hostname := d.ServerHostName() - AssertEqualBytes(t, hostname[:], expectedHostname, "server host name") - bootfilename := d.BootFileName() - AssertEqualBytes(t, bootfilename[:], expectedBootfilename, "boot file name") + require.Equal(t, hostname[:], expectedHostname) + bootfileName := d.BootFileName() + require.Equal(t, bootfileName[:], expectedBootfilename) // no need to check Magic Cookie as it is already validated in FromBytes // above } @@ -87,17 +69,13 @@ func TestFromBytes(t *testing.T) { func TestFromBytesZeroLength(t *testing.T) { data := []byte{} _, err := FromBytes(data) - if err == nil { - t.Fatal("Expected error, got nil") - } + require.Error(t, err) } func TestFromBytesShortLength(t *testing.T) { data := []byte{1, 1, 6, 0} _, err := FromBytes(data) - if err == nil { - t.Fatal("Expected error, got nil") - } + require.Error(t, err) } func TestFromBytesInvalidOptions(t *testing.T) { @@ -126,9 +104,7 @@ func TestFromBytesInvalidOptions(t *testing.T) { // invalid magic cookie, forcing option parsing to fail data = append(data, []byte{99, 130, 83, 98}...) _, err := FromBytes(data) - if err == nil { - t.Fatal("Expected error, got nil") - } + require.Error(t, err) } func TestSettersAndGetters(t *testing.T) { @@ -161,91 +137,89 @@ func TestSettersAndGetters(t *testing.T) { // magic cookie, then no options data = append(data, []byte{99, 130, 83, 99}...) d, err := FromBytes(data) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // getter/setter for Opcode - AssertEqual(t, d.Opcode(), OpcodeBootRequest, "opcode") + require.Equal(t, OpcodeBootRequest, d.Opcode()) d.SetOpcode(OpcodeBootReply) - AssertEqual(t, d.Opcode(), OpcodeBootReply, "opcode") + require.Equal(t, OpcodeBootReply, d.Opcode()) // getter/setter for HwType - AssertEqual(t, d.HwType(), iana.HwTypeEthernet, "hardware type") + require.Equal(t, iana.HwTypeEthernet, d.HwType()) d.SetHwType(iana.HwTypeARCNET) - AssertEqual(t, d.HwType(), iana.HwTypeARCNET, "hardware type") + require.Equal(t, iana.HwTypeARCNET, d.HwType()) // getter/setter for HwAddrLen - AssertEqual(t, d.HwAddrLen(), uint8(6), "hardware address length") + require.Equal(t, uint8(6), d.HwAddrLen()) d.SetHwAddrLen(12) - AssertEqual(t, d.HwAddrLen(), uint8(12), "hardware address length") + require.Equal(t, uint8(12), d.HwAddrLen()) // getter/setter for HopCount - AssertEqual(t, d.HopCount(), uint8(3), "hop count") + require.Equal(t, uint8(3), d.HopCount()) d.SetHopCount(1) - AssertEqual(t, d.HopCount(), uint8(1), "hop count") + require.Equal(t, uint8(1), d.HopCount()) // getter/setter for TransactionID - AssertEqual(t, d.TransactionID(), uint32(0xaabbccdd), "transaction ID") + require.Equal(t, uint32(0xaabbccdd), d.TransactionID()) d.SetTransactionID(0xeeff0011) - AssertEqual(t, d.TransactionID(), uint32(0xeeff0011), "transaction ID") + require.Equal(t, uint32(0xeeff0011), d.TransactionID()) // getter/setter for TransactionID - AssertEqual(t, d.NumSeconds(), uint16(3), "number of seconds") + require.Equal(t, uint16(3), d.NumSeconds()) d.SetNumSeconds(15) - AssertEqual(t, d.NumSeconds(), uint16(15), "number of seconds") + require.Equal(t, uint16(15), d.NumSeconds()) // getter/setter for Flags - AssertEqual(t, d.Flags(), uint16(1), "flags") + require.Equal(t, uint16(1), d.Flags()) d.SetFlags(0) - AssertEqual(t, d.Flags(), uint16(0), "flags") + require.Equal(t, uint16(0), d.Flags()) // getter/setter for ClientIPAddr - AssertEqualIPAddr(t, d.ClientIPAddr(), net.IPv4(1, 2, 3, 4), "client IP address") + RequireEqualIPAddr(t, net.IPv4(1, 2, 3, 4), d.ClientIPAddr()) d.SetClientIPAddr(net.IPv4(4, 3, 2, 1)) - AssertEqualIPAddr(t, d.ClientIPAddr(), net.IPv4(4, 3, 2, 1), "client IP address") + RequireEqualIPAddr(t, net.IPv4(4, 3, 2, 1), d.ClientIPAddr()) // getter/setter for YourIPAddr - AssertEqualIPAddr(t, d.YourIPAddr(), net.IPv4(5, 6, 7, 8), "your IP address") + RequireEqualIPAddr(t, net.IPv4(5, 6, 7, 8), d.YourIPAddr()) d.SetYourIPAddr(net.IPv4(8, 7, 6, 5)) - AssertEqualIPAddr(t, d.YourIPAddr(), net.IPv4(8, 7, 6, 5), "your IP address") + RequireEqualIPAddr(t, net.IPv4(8, 7, 6, 5), d.YourIPAddr()) // getter/setter for ServerIPAddr - AssertEqualIPAddr(t, d.ServerIPAddr(), net.IPv4(9, 10, 11, 12), "server IP address") + RequireEqualIPAddr(t, net.IPv4(9, 10, 11, 12), d.ServerIPAddr()) d.SetServerIPAddr(net.IPv4(12, 11, 10, 9)) - AssertEqualIPAddr(t, d.ServerIPAddr(), net.IPv4(12, 11, 10, 9), "server IP address") + RequireEqualIPAddr(t, net.IPv4(12, 11, 10, 9), d.ServerIPAddr()) // getter/setter for GatewayIPAddr - AssertEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4(13, 14, 15, 16), "gateway IP address") + RequireEqualIPAddr(t, net.IPv4(13, 14, 15, 16), d.GatewayIPAddr()) d.SetGatewayIPAddr(net.IPv4(16, 15, 14, 13)) - AssertEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4(16, 15, 14, 13), "gateway IP address") + RequireEqualIPAddr(t, net.IPv4(16, 15, 14, 13), d.GatewayIPAddr()) // getter/setter for ClientHwAddr hwaddr := d.ClientHwAddr() - AssertEqualBytes(t, hwaddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, "client hardware address") + require.Equal(t, []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, hwaddr[:]) d.SetFlags(0) // getter/setter for ServerHostName serverhostname := d.ServerHostName() - AssertEqualBytes(t, serverhostname[:], expectedHostname, "server host name") + require.Equal(t, expectedHostname, serverhostname[:]) newHostname := []byte{'t', 'e', 's', 't'} for i := 0; i < 60; i++ { newHostname = append(newHostname, 0) } d.SetServerHostName(newHostname) serverhostname = d.ServerHostName() - AssertEqualBytes(t, serverhostname[:], newHostname, "server host name") + require.Equal(t, newHostname, serverhostname[:]) // getter/setter for BootFileName bootfilename := d.BootFileName() - AssertEqualBytes(t, bootfilename[:], expectedBootfilename, "boot file name") + require.Equal(t, expectedBootfilename, bootfilename[:]) newBootfilename := []byte{'t', 'e', 's', 't'} for i := 0; i < 124; i++ { newBootfilename = append(newBootfilename, 0) } d.SetBootFileName(newBootfilename) bootfilename = d.BootFileName() - AssertEqualBytes(t, bootfilename[:], newBootfilename, "boot file name") + require.Equal(t, newBootfilename, bootfilename[:]) } func TestToStringMethods(t *testing.T) { @@ -255,38 +229,38 @@ func TestToStringMethods(t *testing.T) { } // OpcodeToString d.SetOpcode(OpcodeBootRequest) - AssertEqual(t, d.OpcodeToString(), "BootRequest", "OpcodeToString") + require.Equal(t, "BootRequest", d.OpcodeToString()) d.SetOpcode(OpcodeBootReply) - AssertEqual(t, d.OpcodeToString(), "BootReply", "OpcodeToString") + require.Equal(t, "BootReply", d.OpcodeToString()) d.SetOpcode(OpcodeType(0)) - AssertEqual(t, d.OpcodeToString(), "Invalid", "OpcodeToString") + require.Equal(t, "Invalid", d.OpcodeToString()) // HwTypeToString d.SetHwType(iana.HwTypeEthernet) - AssertEqual(t, d.HwTypeToString(), "Ethernet", "HwTypeToString") + require.Equal(t, "Ethernet", d.HwTypeToString()) d.SetHwType(iana.HwTypeARCNET) - AssertEqual(t, d.HwTypeToString(), "ARCNET", "HwTypeToString") + require.Equal(t, "ARCNET", d.HwTypeToString()) // FlagsToString d.SetUnicast() - AssertEqual(t, d.FlagsToString(), "Unicast", "FlagsToString") + require.Equal(t, "Unicast", d.FlagsToString()) d.SetBroadcast() - AssertEqual(t, d.FlagsToString(), "Broadcast", "FlagsToString") + require.Equal(t, "Broadcast", d.FlagsToString()) d.SetFlags(0xffff) - AssertEqual(t, d.FlagsToString(), "Broadcast (reserved bits not zeroed)", "FlagsToString") + require.Equal(t, "Broadcast (reserved bits not zeroed)", d.FlagsToString()) // ClientHwAddrToString d.SetHwAddrLen(6) d.SetClientHwAddr([]byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - AssertEqual(t, d.ClientHwAddrToString(), "aa:bb:cc:dd:ee:ff", "ClientHwAddrToString") + require.Equal(t, "aa:bb:cc:dd:ee:ff", d.ClientHwAddrToString()) // ServerHostNameToString d.SetServerHostName([]byte("my.host.local")) - AssertEqual(t, d.ServerHostNameToString(), "my.host.local", "ServerHostNameToString") + require.Equal(t, "my.host.local", d.ServerHostNameToString()) // BootFileNameToString d.SetBootFileName([]byte("/my/boot/file")) - AssertEqual(t, d.BootFileNameToString(), "/my/boot/file", "BootFileNameToString") + require.Equal(t, "/my/boot/file", d.BootFileNameToString()) } func TestToBytes(t *testing.T) { @@ -318,14 +292,12 @@ func TestToBytes(t *testing.T) { expected = append(expected, MagicCookie...) d, err := New() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // fix TransactionID to match the expected one, since it's randomly // generated in New() d.SetTransactionID(0x11223344) got := d.ToBytes() - AssertEqualBytes(t, expected, got, "ToBytes") + require.Equal(t, expected, got) } // TODO diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go index 07be9bf..ca860cf 100644 --- a/dhcpv4/options_test.go +++ b/dhcpv4/options_test.go @@ -1,52 +1,37 @@ package dhcpv4 import ( - "bytes" "testing" + + "github.com/stretchr/testify/require" ) func TestParseOption(t *testing.T) { option := []byte{5, 4, 192, 168, 1, 254} // DNS option opt, err := ParseOption(option) - if err != nil { - t.Fatal(err) - } - if opt.Code != OptionNameServer { - t.Fatalf("Invalid option code. Expected 5, got %v", opt.Code) - } - if !bytes.Equal(opt.Data, option[2:]) { - t.Fatalf("Invalid option data. Expected %v, got %v", option[2:], opt.Data) - } + require.NoError(t, err, "should not get error from parsing option") + require.Equal(t, OptionNameServer, opt.Code, "opt should have the same opcode") + require.Equal(t, option[2:], opt.Data, "opt should have the same data") } func TestParseOptionPad(t *testing.T) { option := []byte{0} opt, err := ParseOption(option) - if err != nil { - t.Fatal(err) - } - if opt.Code != OptionPad { - t.Fatalf("Invalid option code. Expected %v, got %v", OptionPad, opt.Code) - } - if len(opt.Data) != 0 { - t.Fatalf("Invalid option data. Expected empty slice, got %v", opt.Data) - } + require.NoError(t, err, "should not get error from parsing option") + require.Equal(t, OptionPad, opt.Code, "should get pad option code") + require.Empty(t, opt.Data, "should get empty data with pad option") } func TestParseOptionZeroLength(t *testing.T) { option := []byte{} _, err := ParseOption(option) - if err == nil { - t.Fatal("Expected an error, got none") - } + require.Error(t, err, "should get error from zero-length options") } func TestParseOptionShortOption(t *testing.T) { option := []byte{53, 1} _, err := ParseOption(option) - if err == nil { - t.Fatal(err) - } + require.Error(t, err, "should get error from short options") } func TestOptionsFromBytes(t *testing.T) { @@ -57,41 +42,29 @@ func TestOptionsFromBytes(t *testing.T) { 0, 0, 0, //padding } opts, err := OptionsFromBytes(options) - if err != nil { - t.Fatal(err) - } - // each padding byte counts as an option. Magic Cookie doesn't add up - if len(opts) != 5 { - t.Fatal("Invalid options length. Expected 5, got %v", len(opts)) - } - if opts[0].Code != OptionNameServer { - t.Fatal("Invalid option code. Expected %v, got %v", OptionNameServer, opts[0].Code) - } - if !bytes.Equal(opts[0].Data, options[6:10]) { - t.Fatal("Invalid option data. Expected %v, got %v", options[6:10], opts[0].Data) - } - if opts[1].Code != OptionEnd { - t.Fatalf("Invalid option code. Expected %v, got %v", OptionEnd, opts[1].Code) - } - if opts[2].Code != OptionPad { - t.Fatalf("Invalid option code. Expected %v, got %v", OptionPad, opts[2].Code) - } + require.NoError(t, err) + require.Equal(t, []Option{ + Option{ + Code: OptionNameServer, + Data: []byte{192, 168, 1, 1}, + }, + Option{Code: OptionEnd, Data: []byte{}}, + Option{Code: OptionPad, Data: []byte{}}, + Option{Code: OptionPad, Data: []byte{}}, + Option{Code: OptionPad, Data: []byte{}}, + }, opts) } func TestOptionsFromBytesZeroLength(t *testing.T) { options := []byte{} _, err := OptionsFromBytes(options) - if err == nil { - t.Fatal("Expected an error, got none") - } + require.Error(t, err) } func TestOptionsFromBytesBadMagicCookie(t *testing.T) { options := []byte{1, 2, 3, 4} _, err := OptionsFromBytes(options) - if err == nil { - t.Fatal("Expected an error, got none") - } + require.Error(t, err) } func TestOptionsToBytes(t *testing.T) { @@ -102,47 +75,31 @@ func TestOptionsToBytes(t *testing.T) { 0, 0, 0, //padding } options, err := OptionsFromBytes(originalOptions) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) finalOptions := OptionsToBytes(options) - if !bytes.Equal(originalOptions, finalOptions) { - t.Fatalf("Invalid options. Expected %v, got %v", originalOptions, finalOptions) - } + require.Equal(t, originalOptions, finalOptions) } func TestOptionsToBytesEmpty(t *testing.T) { originalOptions := []byte{99, 130, 83, 99} options, err := OptionsFromBytes(originalOptions) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) finalOptions := OptionsToBytes(options) - if !bytes.Equal(originalOptions, finalOptions) { - t.Fatalf("Invalid options. Expected %v, got %v", originalOptions, finalOptions) - } + require.Equal(t, originalOptions, finalOptions) } func TestOptionsToStringPad(t *testing.T) { option := []byte{0} opt, err := ParseOption(option) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) stropt := opt.String() - if stropt != "Pad -> []" { - t.Fatalf("Invalid string representation: %v", stropt) - } + require.Equal(t, "Pad -> []", stropt) } func TestOptionsToStringDHCPMessageType(t *testing.T) { option := []byte{53, 1, 5} opt, err := ParseOption(option) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) stropt := opt.String() - if stropt != "DHCP Message Type -> [5]" { - t.Fatalf("Invalid string representation: %v", stropt) - } + require.Equal(t, "DHCP Message Type -> [5]", stropt) } |