summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv4')
-rw-r--r--dhcpv4/bsdp/bsdp.go14
-rw-r--r--dhcpv4/bsdp/bsdp_test.go103
-rw-r--r--dhcpv4/dhcpv4_test.go154
-rw-r--r--dhcpv4/options_test.go105
4 files changed, 152 insertions, 224 deletions
diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go
index 6c8d00f..d4d4e31 100644
--- a/dhcpv4/bsdp/bsdp.go
+++ b/dhcpv4/bsdp/bsdp.go
@@ -217,8 +217,8 @@ func NewInformListForInterface(iface string, replyPort uint16) (*dhcpv4.DHCPv4,
d.AddOption(dhcpv4.Option{
Code: dhcpv4.OptionParameterRequestList,
Data: []byte{
- dhcpv4.OptionVendorSpecificInformation,
- dhcpv4.OptionClassIdentifier,
+ byte(dhcpv4.OptionVendorSpecificInformation),
+ byte(dhcpv4.OptionClassIdentifier),
},
})
@@ -316,11 +316,11 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI
d.AddOption(dhcpv4.Option{
Code: dhcpv4.OptionParameterRequestList,
Data: []byte{
- dhcpv4.OptionSubnetMask,
- dhcpv4.OptionRouter,
- dhcpv4.OptionBootfileName,
- dhcpv4.OptionVendorSpecificInformation,
- dhcpv4.OptionClassIdentifier,
+ byte(dhcpv4.OptionSubnetMask),
+ byte(dhcpv4.OptionRouter),
+ byte(dhcpv4.OptionBootfileName),
+ byte(dhcpv4.OptionVendorSpecificInformation),
+ byte(dhcpv4.OptionClassIdentifier),
},
})
d.AddOption(dhcpv4.Option{
diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go
index b66efbc..9d5ca1f 100644
--- a/dhcpv4/bsdp/bsdp_test.go
+++ b/dhcpv4/bsdp/bsdp_test.go
@@ -4,7 +4,7 @@ import (
"testing"
"github.com/insomniacslk/dhcp/dhcpv4"
- "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
/*
@@ -18,12 +18,12 @@ func TestBootImageIDToBytes(t *testing.T) {
}
actual := b.ToBytes()
expected := []byte{0x81, 0, 0x10, 0}
- assert.Equal(t, actual, expected, "serialized BootImageID should be equal")
+ require.Equal(t, expected, actual)
b.IsInstall = false
actual = b.ToBytes()
expected = []byte{0x01, 0, 0x10, 0}
- assert.Equal(t, actual, expected, "serialized BootImageID should be equal")
+ require.Equal(t, expected, actual)
}
func TestBootImageIDFromBytes(t *testing.T) {
@@ -33,8 +33,8 @@ func TestBootImageIDFromBytes(t *testing.T) {
Index: 0x1000,
}
newBootImage, err := BootImageIDFromBytes(b.ToBytes())
- assert.Nil(t, err, "error from BootImageIDFromBytes")
- assert.Equal(t, b, *newBootImage, "deserialized BootImage should be equal")
+ require.NoError(t, err)
+ require.Equal(t, b, *newBootImage)
b = BootImageID{
IsInstall: true,
@@ -42,15 +42,15 @@ func TestBootImageIDFromBytes(t *testing.T) {
Index: 0x1011,
}
newBootImage, err = BootImageIDFromBytes(b.ToBytes())
- assert.Nil(t, err, "error from BootImageIDFromBytes")
- assert.Equal(t, b, *newBootImage, "deserialized BootImage should be equal")
+ require.NoError(t, err)
+ require.Equal(t, b, *newBootImage)
}
func TestBootImageIDFromBytesFail(t *testing.T) {
serialized := []byte{0x81, 0, 0x10} // intentionally left short
deserialized, err := BootImageIDFromBytes(serialized)
- assert.Nil(t, deserialized, "BootImageIDFromBytes should return nil on failed deserialization")
- assert.NotNil(t, err, "BootImageIDFromBytes should return err on failed deserialization")
+ require.Nil(t, deserialized)
+ require.Error(t, err)
}
/*
@@ -71,7 +71,7 @@ func TestBootImageToBytes(t *testing.T) {
98, 115, 100, 112, 45, 49, // byte-encoding of Name
}
actual := b.ToBytes()
- assert.Equal(t, actual, expected, "serialized BootImage should be equal")
+ require.Equal(t, expected, actual)
b = BootImage{
ID: BootImageID{
@@ -87,7 +87,7 @@ func TestBootImageToBytes(t *testing.T) {
98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name
}
actual = b.ToBytes()
- assert.Equal(t, actual, expected, "serialized BootImage should be equal")
+ require.Equal(t, expected, actual)
}
func TestBootImageFromBytes(t *testing.T) {
@@ -97,7 +97,7 @@ func TestBootImageFromBytes(t *testing.T) {
98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name
}
b, err := BootImageFromBytes(input)
- assert.Nil(t, err, "error while marshalling BootImage")
+ require.NoError(t, err)
expectedBootImage := BootImage{
ID: BootImageID{
IsInstall: false,
@@ -106,15 +106,15 @@ func TestBootImageFromBytes(t *testing.T) {
},
Name: "bsdp-21",
}
- assert.Equal(t, *b, expectedBootImage, "invalid marshalling of BootImage")
+ require.Equal(t, expectedBootImage, *b)
}
func TestBootImageFromBytesOnlyBootImageID(t *testing.T) {
// Only a BootImageID, nothing else.
input := []byte{0x1, 0, 0x10, 0x10}
b, err := BootImageFromBytes(input)
- assert.Nil(t, b, "short bytestream should return nil BootImageID")
- assert.NotNil(t, err, "short bytestream should return error")
+ require.Nil(t, b)
+ require.Error(t, err)
}
func TestBootImageFromBytesShortBootImage(t *testing.T) {
@@ -124,8 +124,8 @@ func TestBootImageFromBytesShortBootImage(t *testing.T) {
98, 115, 100, 112, 45, 50, // Name bytes (intentionally off-by-one)
}
b, err := BootImageFromBytes(input)
- assert.Nil(t, b, "short bytestream should return nil BootImageID")
- assert.NotNil(t, err, "short bytestream should return error")
+ require.Nil(t, b)
+ require.Error(t, err)
}
func TestParseBootImageSingleBootImage(t *testing.T) {
@@ -135,16 +135,16 @@ func TestParseBootImageSingleBootImage(t *testing.T) {
98, 115, 100, 112, 45, 50, 49, // byte-encoding of Name
}
bs, err := ParseBootImagesFromOption(input)
- assert.Nil(t, err, "parsing single boot image should not return error")
- assert.Equal(t, len(bs), 1, "parsing single boot image should return 1")
+ require.NoError(t, err)
+ require.Equal(t, len(bs), 1, "parsing single boot image should return 1")
b := bs[0]
- expectedBootImage := BootImageID{
+ expectedBootImageID := BootImageID{
IsInstall: false,
ImageType: BootImageTypeMacOSX,
Index: 0x1010,
}
- assert.Equal(t, b.ID, expectedBootImage, "parsed BootImageIDs should be equal")
- assert.Equal(t, b.Name, "bsdp-21", "BootImage name should be equal")
+ require.Equal(t, expectedBootImageID, b.ID)
+ require.Equal(t, b.Name, "bsdp-21")
}
func TestParseBootImageMultipleBootImage(t *testing.T) {
@@ -160,8 +160,8 @@ func TestParseBootImageMultipleBootImage(t *testing.T) {
98, 115, 100, 112, 45, 50, 50, 50, // byte-encoding of Name
}
bs, err := ParseBootImagesFromOption(input)
- assert.Nil(t, err, "parsing multiple BootImages should not return error")
- assert.Equal(t, len(bs), 2, "parsing 2 BootImages should return 2")
+ require.NoError(t, err)
+ require.Equal(t, len(bs), 2, "parsing 2 BootImages should return 2")
b1 := bs[0]
b2 := bs[1]
expectedID1 := BootImageID{
@@ -174,18 +174,18 @@ func TestParseBootImageMultipleBootImage(t *testing.T) {
ImageType: BootImageTypeMacOSXServer,
Index: 0x1122,
}
- assert.Equal(t, b1.ID, expectedID1, "first BootImageID should be equal")
- assert.Equal(t, b2.ID, expectedID2, "second BootImageID should be equal")
- assert.Equal(t, b1.Name, "bsdp-21", "first BootImage name should be equal")
- assert.Equal(t, b2.Name, "bsdp-222", "second BootImage name should be equal")
+ require.Equal(t, expectedID1, b1.ID, "first BootImageID should be equal")
+ require.Equal(t, expectedID2, b2.ID, "second BootImageID should be equal")
+ require.Equal(t, "bsdp-21", b1.Name, "first BootImage name should be equal")
+ require.Equal(t, "bsdp-222", b2.Name, "second BootImage name should be equal")
}
func TestParseBootImageFail(t *testing.T) {
_, err := ParseBootImagesFromOption([]byte{})
- assert.NotNil(t, err, "parseBootImages with empty arg")
+ require.Error(t, err, "parseBootImages with empty arg")
_, err = ParseBootImagesFromOption([]byte{1, 2, 3})
- assert.NotNil(t, err, "parseBootImages with short arg")
+ require.Error(t, err, "parseBootImages with short arg")
_, err = ParseBootImagesFromOption([]byte{
// boot image 1
@@ -198,14 +198,14 @@ func TestParseBootImageFail(t *testing.T) {
8, // len(Name)
98, 115, 100, 112, 45, 50, 50, 50, // byte-encoding of Name
})
- assert.NotNil(t, err, "parseBootImages with short arg")
+ require.Error(t, err, "parseBootImages with short arg")
}
/*
* ParseVendorOptionsFromOptions
*/
func TestParseVendorOptions(t *testing.T) {
- vendorOpts := []dhcpv4.Option{
+ expectedOpts := []dhcpv4.Option{
dhcpv4.Option{
Code: OptionMessageType,
Data: []byte{byte(MessageTypeList)},
@@ -226,15 +226,15 @@ func TestParseVendorOptions(t *testing.T) {
},
dhcpv4.Option{
Code: dhcpv4.OptionVendorSpecificInformation,
- Data: dhcpv4.OptionsToBytesWithoutMagicCookie(vendorOpts),
+ Data: dhcpv4.OptionsToBytesWithoutMagicCookie(expectedOpts),
},
}
opts := ParseVendorOptionsFromOptions(recvOpts)
- assert.Equal(t, opts, vendorOpts, "Parsed vendorOpts should be the same")
+ require.Equal(t, expectedOpts, opts, "Parsed vendorOpts should be the same")
}
func TestParseVendorOptionsFromOptionsNotPresent(t *testing.T) {
- recvOpts := []dhcpv4.Option{
+ expectedOpts := []dhcpv4.Option{
dhcpv4.Option{
Code: dhcpv4.OptionDHCPMessageType,
Data: []byte{byte(dhcpv4.MessageTypeAck)},
@@ -244,13 +244,13 @@ func TestParseVendorOptionsFromOptionsNotPresent(t *testing.T) {
Data: []byte{0xff, 0xff, 0xff, 0xff},
},
}
- opts := ParseVendorOptionsFromOptions(recvOpts)
- assert.Empty(t, opts, "vendor opts should be empty if not present in input")
+ opts := ParseVendorOptionsFromOptions(expectedOpts)
+ require.Empty(t, opts, "empty vendor opts if not present in DHCP opts")
}
func TestParseVendorOptionsFromOptionsEmpty(t *testing.T) {
- options := ParseVendorOptionsFromOptions([]dhcpv4.Option{})
- assert.Empty(t, options, "vendor opts should be empty if given an empty input")
+ opts := ParseVendorOptionsFromOptions([]dhcpv4.Option{})
+ require.Empty(t, opts, "vendor opts should be empty if given an empty input")
}
func TestParseVendorOptionsFromOptionsFail(t *testing.T) {
@@ -264,14 +264,14 @@ func TestParseVendorOptionsFromOptionsFail(t *testing.T) {
},
}
vendorOpts := ParseVendorOptionsFromOptions(opts)
- assert.Empty(t, vendorOpts, "vendor opts should be empty on parse error")
+ require.Empty(t, vendorOpts, "vendor opts should be empty on parse error")
}
/*
* ParseBootImageListFromAck
*/
func TestParseBootImageListFromAck(t *testing.T) {
- bootImages := []BootImage{
+ expectedBootImages := []BootImage{
BootImage{
ID: BootImageID{
IsInstall: true,
@@ -290,7 +290,7 @@ func TestParseBootImageListFromAck(t *testing.T) {
},
}
var bootImageBytes []byte
- for _, image := range bootImages {
+ for _, image := range expectedBootImages {
bootImageBytes = append(bootImageBytes, image.ToBytes()...)
}
ack, _ := dhcpv4.New()
@@ -305,9 +305,8 @@ func TestParseBootImageListFromAck(t *testing.T) {
})
images, err := ParseBootImageListFromAck(*ack)
- assert.Nil(t, err, "error from ParseBootImageListFromAck")
- assert.NotNil(t, images, "parsed boot images from ack")
- assert.Equal(t, images, bootImages, "should get same BootImages")
+ require.NoError(t, err)
+ require.Equal(t, expectedBootImages, images, "should get same BootImages")
}
func TestParseBootImageListFromAckNoVendorOption(t *testing.T) {
@@ -317,8 +316,8 @@ func TestParseBootImageListFromAckNoVendorOption(t *testing.T) {
Data: []byte{byte(dhcpv4.MessageTypeAck)},
})
images, err := ParseBootImageListFromAck(*ack)
- assert.Nil(t, err, "no vendor extensions should not return error")
- assert.Empty(t, images, "should not get images from ACK without Vendor extensions")
+ require.NoError(t, err, "no vendor extensions should not return error")
+ require.Empty(t, images, "should not get images from ACK without Vendor extensions")
}
func TestParseBootImageListFromAckFail(t *testing.T) {
@@ -348,15 +347,15 @@ func TestParseBootImageListFromAckFail(t *testing.T) {
})
images, err := ParseBootImageListFromAck(*ack)
- assert.Nil(t, images, "should get nil on parse error")
- assert.NotNil(t, err, "should get error on parse error")
+ require.Nil(t, images, "should get nil on parse error")
+ require.Error(t, err, "should get error on parse error")
}
/*
* Private funcs
*/
func TestNeedsReplyPort(t *testing.T) {
- assert.True(t, needsReplyPort(123), "")
- assert.False(t, needsReplyPort(0), "")
- assert.False(t, needsReplyPort(dhcpv4.ClientPort), "")
+ require.True(t, needsReplyPort(123))
+ require.False(t, needsReplyPort(0))
+ require.False(t, needsReplyPort(dhcpv4.ClientPort))
}
diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go
index 0d569c8..e9824ef 100644
--- a/dhcpv4/dhcpv4_test.go
+++ b/dhcpv4/dhcpv4_test.go
@@ -1,31 +1,16 @@
package dhcpv4
import (
- "bytes"
"net"
"testing"
"github.com/insomniacslk/dhcp/iana"
+ "github.com/stretchr/testify/require"
)
-// NOTE: if one of the following Assert* fails where expected and got values are
-// the same, you probably have to cast one of them to match the other one's
-// type, e.g. comparing int and byte, even the same value, will fail.
-func AssertEqual(t *testing.T, a, b interface{}, what string) {
- if a != b {
- t.Fatalf("Invalid %s. %v != %v", what, a, b)
- }
-}
-
-func AssertEqualBytes(t *testing.T, a, b []byte, what string) {
- if !bytes.Equal(a, b) {
- t.Fatalf("Invalid %s. %v != %v", what, a, b)
- }
-}
-
-func AssertEqualIPAddr(t *testing.T, a, b net.IP, what string) {
+func RequireEqualIPAddr(t *testing.T, a, b net.IP, msg ...interface{}) {
if !net.IP.Equal(a, b) {
- t.Fatalf("Invalid %s. %v != %v", what, a, b)
+ t.Fatalf("Invalid %s. %v != %v", msg, a, b)
}
}
@@ -60,26 +45,23 @@ func TestFromBytes(t *testing.T) {
data = append(data, []byte{99, 130, 83, 99}...)
d, err := FromBytes(data)
- if err != nil {
- t.Fatal(err)
- }
- AssertEqual(t, d.Opcode(), OpcodeBootRequest, "opcode")
- AssertEqual(t, d.HwType(), iana.HwTypeEthernet, "hardware type")
- AssertEqual(t, d.HwAddrLen(), byte(6), "hardware address length")
- AssertEqual(t, d.HopCount(), byte(3), "hop count")
- AssertEqual(t, d.TransactionID(), uint32(0xaabbccdd), "transaction ID")
- AssertEqual(t, d.NumSeconds(), uint16(3), "number of seconds")
- AssertEqual(t, d.Flags(), uint16(1), "flags")
- AssertEqualIPAddr(t, d.ClientIPAddr(), net.IPv4zero, "client IP address")
- AssertEqualIPAddr(t, d.YourIPAddr(), net.IPv4zero, "your IP address")
- AssertEqualIPAddr(t, d.ServerIPAddr(), net.IPv4zero, "server IP address")
- AssertEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4zero, "gateway IP address")
- hwaddr := d.ClientHwAddr()
- AssertEqualBytes(t, hwaddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, "flags")
+ require.NoError(t, err)
+ require.Equal(t, d.Opcode(), OpcodeBootRequest)
+ require.Equal(t, d.HwType(), iana.HwTypeEthernet)
+ require.Equal(t, d.HwAddrLen(), byte(6))
+ require.Equal(t, d.HopCount(), byte(3))
+ require.Equal(t, d.TransactionID(), uint32(0xaabbccdd))
+ require.Equal(t, d.NumSeconds(), uint16(3))
+ require.Equal(t, d.Flags(), uint16(1))
+ RequireEqualIPAddr(t, d.ClientIPAddr(), net.IPv4zero)
+ RequireEqualIPAddr(t, d.YourIPAddr(), net.IPv4zero)
+ RequireEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4zero)
+ clientHwAddr := d.ClientHwAddr()
+ require.Equal(t, clientHwAddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
hostname := d.ServerHostName()
- AssertEqualBytes(t, hostname[:], expectedHostname, "server host name")
- bootfilename := d.BootFileName()
- AssertEqualBytes(t, bootfilename[:], expectedBootfilename, "boot file name")
+ require.Equal(t, hostname[:], expectedHostname)
+ bootfileName := d.BootFileName()
+ require.Equal(t, bootfileName[:], expectedBootfilename)
// no need to check Magic Cookie as it is already validated in FromBytes
// above
}
@@ -87,17 +69,13 @@ func TestFromBytes(t *testing.T) {
func TestFromBytesZeroLength(t *testing.T) {
data := []byte{}
_, err := FromBytes(data)
- if err == nil {
- t.Fatal("Expected error, got nil")
- }
+ require.Error(t, err)
}
func TestFromBytesShortLength(t *testing.T) {
data := []byte{1, 1, 6, 0}
_, err := FromBytes(data)
- if err == nil {
- t.Fatal("Expected error, got nil")
- }
+ require.Error(t, err)
}
func TestFromBytesInvalidOptions(t *testing.T) {
@@ -126,9 +104,7 @@ func TestFromBytesInvalidOptions(t *testing.T) {
// invalid magic cookie, forcing option parsing to fail
data = append(data, []byte{99, 130, 83, 98}...)
_, err := FromBytes(data)
- if err == nil {
- t.Fatal("Expected error, got nil")
- }
+ require.Error(t, err)
}
func TestSettersAndGetters(t *testing.T) {
@@ -161,91 +137,89 @@ func TestSettersAndGetters(t *testing.T) {
// magic cookie, then no options
data = append(data, []byte{99, 130, 83, 99}...)
d, err := FromBytes(data)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
// getter/setter for Opcode
- AssertEqual(t, d.Opcode(), OpcodeBootRequest, "opcode")
+ require.Equal(t, OpcodeBootRequest, d.Opcode())
d.SetOpcode(OpcodeBootReply)
- AssertEqual(t, d.Opcode(), OpcodeBootReply, "opcode")
+ require.Equal(t, OpcodeBootReply, d.Opcode())
// getter/setter for HwType
- AssertEqual(t, d.HwType(), iana.HwTypeEthernet, "hardware type")
+ require.Equal(t, iana.HwTypeEthernet, d.HwType())
d.SetHwType(iana.HwTypeARCNET)
- AssertEqual(t, d.HwType(), iana.HwTypeARCNET, "hardware type")
+ require.Equal(t, iana.HwTypeARCNET, d.HwType())
// getter/setter for HwAddrLen
- AssertEqual(t, d.HwAddrLen(), uint8(6), "hardware address length")
+ require.Equal(t, uint8(6), d.HwAddrLen())
d.SetHwAddrLen(12)
- AssertEqual(t, d.HwAddrLen(), uint8(12), "hardware address length")
+ require.Equal(t, uint8(12), d.HwAddrLen())
// getter/setter for HopCount
- AssertEqual(t, d.HopCount(), uint8(3), "hop count")
+ require.Equal(t, uint8(3), d.HopCount())
d.SetHopCount(1)
- AssertEqual(t, d.HopCount(), uint8(1), "hop count")
+ require.Equal(t, uint8(1), d.HopCount())
// getter/setter for TransactionID
- AssertEqual(t, d.TransactionID(), uint32(0xaabbccdd), "transaction ID")
+ require.Equal(t, uint32(0xaabbccdd), d.TransactionID())
d.SetTransactionID(0xeeff0011)
- AssertEqual(t, d.TransactionID(), uint32(0xeeff0011), "transaction ID")
+ require.Equal(t, uint32(0xeeff0011), d.TransactionID())
// getter/setter for TransactionID
- AssertEqual(t, d.NumSeconds(), uint16(3), "number of seconds")
+ require.Equal(t, uint16(3), d.NumSeconds())
d.SetNumSeconds(15)
- AssertEqual(t, d.NumSeconds(), uint16(15), "number of seconds")
+ require.Equal(t, uint16(15), d.NumSeconds())
// getter/setter for Flags
- AssertEqual(t, d.Flags(), uint16(1), "flags")
+ require.Equal(t, uint16(1), d.Flags())
d.SetFlags(0)
- AssertEqual(t, d.Flags(), uint16(0), "flags")
+ require.Equal(t, uint16(0), d.Flags())
// getter/setter for ClientIPAddr
- AssertEqualIPAddr(t, d.ClientIPAddr(), net.IPv4(1, 2, 3, 4), "client IP address")
+ RequireEqualIPAddr(t, net.IPv4(1, 2, 3, 4), d.ClientIPAddr())
d.SetClientIPAddr(net.IPv4(4, 3, 2, 1))
- AssertEqualIPAddr(t, d.ClientIPAddr(), net.IPv4(4, 3, 2, 1), "client IP address")
+ RequireEqualIPAddr(t, net.IPv4(4, 3, 2, 1), d.ClientIPAddr())
// getter/setter for YourIPAddr
- AssertEqualIPAddr(t, d.YourIPAddr(), net.IPv4(5, 6, 7, 8), "your IP address")
+ RequireEqualIPAddr(t, net.IPv4(5, 6, 7, 8), d.YourIPAddr())
d.SetYourIPAddr(net.IPv4(8, 7, 6, 5))
- AssertEqualIPAddr(t, d.YourIPAddr(), net.IPv4(8, 7, 6, 5), "your IP address")
+ RequireEqualIPAddr(t, net.IPv4(8, 7, 6, 5), d.YourIPAddr())
// getter/setter for ServerIPAddr
- AssertEqualIPAddr(t, d.ServerIPAddr(), net.IPv4(9, 10, 11, 12), "server IP address")
+ RequireEqualIPAddr(t, net.IPv4(9, 10, 11, 12), d.ServerIPAddr())
d.SetServerIPAddr(net.IPv4(12, 11, 10, 9))
- AssertEqualIPAddr(t, d.ServerIPAddr(), net.IPv4(12, 11, 10, 9), "server IP address")
+ RequireEqualIPAddr(t, net.IPv4(12, 11, 10, 9), d.ServerIPAddr())
// getter/setter for GatewayIPAddr
- AssertEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4(13, 14, 15, 16), "gateway IP address")
+ RequireEqualIPAddr(t, net.IPv4(13, 14, 15, 16), d.GatewayIPAddr())
d.SetGatewayIPAddr(net.IPv4(16, 15, 14, 13))
- AssertEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4(16, 15, 14, 13), "gateway IP address")
+ RequireEqualIPAddr(t, net.IPv4(16, 15, 14, 13), d.GatewayIPAddr())
// getter/setter for ClientHwAddr
hwaddr := d.ClientHwAddr()
- AssertEqualBytes(t, hwaddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, "client hardware address")
+ require.Equal(t, []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, hwaddr[:])
d.SetFlags(0)
// getter/setter for ServerHostName
serverhostname := d.ServerHostName()
- AssertEqualBytes(t, serverhostname[:], expectedHostname, "server host name")
+ require.Equal(t, expectedHostname, serverhostname[:])
newHostname := []byte{'t', 'e', 's', 't'}
for i := 0; i < 60; i++ {
newHostname = append(newHostname, 0)
}
d.SetServerHostName(newHostname)
serverhostname = d.ServerHostName()
- AssertEqualBytes(t, serverhostname[:], newHostname, "server host name")
+ require.Equal(t, newHostname, serverhostname[:])
// getter/setter for BootFileName
bootfilename := d.BootFileName()
- AssertEqualBytes(t, bootfilename[:], expectedBootfilename, "boot file name")
+ require.Equal(t, expectedBootfilename, bootfilename[:])
newBootfilename := []byte{'t', 'e', 's', 't'}
for i := 0; i < 124; i++ {
newBootfilename = append(newBootfilename, 0)
}
d.SetBootFileName(newBootfilename)
bootfilename = d.BootFileName()
- AssertEqualBytes(t, bootfilename[:], newBootfilename, "boot file name")
+ require.Equal(t, newBootfilename, bootfilename[:])
}
func TestToStringMethods(t *testing.T) {
@@ -255,38 +229,38 @@ func TestToStringMethods(t *testing.T) {
}
// OpcodeToString
d.SetOpcode(OpcodeBootRequest)
- AssertEqual(t, d.OpcodeToString(), "BootRequest", "OpcodeToString")
+ require.Equal(t, "BootRequest", d.OpcodeToString())
d.SetOpcode(OpcodeBootReply)
- AssertEqual(t, d.OpcodeToString(), "BootReply", "OpcodeToString")
+ require.Equal(t, "BootReply", d.OpcodeToString())
d.SetOpcode(OpcodeType(0))
- AssertEqual(t, d.OpcodeToString(), "Invalid", "OpcodeToString")
+ require.Equal(t, "Invalid", d.OpcodeToString())
// HwTypeToString
d.SetHwType(iana.HwTypeEthernet)
- AssertEqual(t, d.HwTypeToString(), "Ethernet", "HwTypeToString")
+ require.Equal(t, "Ethernet", d.HwTypeToString())
d.SetHwType(iana.HwTypeARCNET)
- AssertEqual(t, d.HwTypeToString(), "ARCNET", "HwTypeToString")
+ require.Equal(t, "ARCNET", d.HwTypeToString())
// FlagsToString
d.SetUnicast()
- AssertEqual(t, d.FlagsToString(), "Unicast", "FlagsToString")
+ require.Equal(t, "Unicast", d.FlagsToString())
d.SetBroadcast()
- AssertEqual(t, d.FlagsToString(), "Broadcast", "FlagsToString")
+ require.Equal(t, "Broadcast", d.FlagsToString())
d.SetFlags(0xffff)
- AssertEqual(t, d.FlagsToString(), "Broadcast (reserved bits not zeroed)", "FlagsToString")
+ require.Equal(t, "Broadcast (reserved bits not zeroed)", d.FlagsToString())
// ClientHwAddrToString
d.SetHwAddrLen(6)
d.SetClientHwAddr([]byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
- AssertEqual(t, d.ClientHwAddrToString(), "aa:bb:cc:dd:ee:ff", "ClientHwAddrToString")
+ require.Equal(t, "aa:bb:cc:dd:ee:ff", d.ClientHwAddrToString())
// ServerHostNameToString
d.SetServerHostName([]byte("my.host.local"))
- AssertEqual(t, d.ServerHostNameToString(), "my.host.local", "ServerHostNameToString")
+ require.Equal(t, "my.host.local", d.ServerHostNameToString())
// BootFileNameToString
d.SetBootFileName([]byte("/my/boot/file"))
- AssertEqual(t, d.BootFileNameToString(), "/my/boot/file", "BootFileNameToString")
+ require.Equal(t, "/my/boot/file", d.BootFileNameToString())
}
func TestToBytes(t *testing.T) {
@@ -318,14 +292,12 @@ func TestToBytes(t *testing.T) {
expected = append(expected, MagicCookie...)
d, err := New()
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
// fix TransactionID to match the expected one, since it's randomly
// generated in New()
d.SetTransactionID(0x11223344)
got := d.ToBytes()
- AssertEqualBytes(t, expected, got, "ToBytes")
+ require.Equal(t, expected, got)
}
// TODO
diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go
index 07be9bf..ca860cf 100644
--- a/dhcpv4/options_test.go
+++ b/dhcpv4/options_test.go
@@ -1,52 +1,37 @@
package dhcpv4
import (
- "bytes"
"testing"
+
+ "github.com/stretchr/testify/require"
)
func TestParseOption(t *testing.T) {
option := []byte{5, 4, 192, 168, 1, 254} // DNS option
opt, err := ParseOption(option)
- if err != nil {
- t.Fatal(err)
- }
- if opt.Code != OptionNameServer {
- t.Fatalf("Invalid option code. Expected 5, got %v", opt.Code)
- }
- if !bytes.Equal(opt.Data, option[2:]) {
- t.Fatalf("Invalid option data. Expected %v, got %v", option[2:], opt.Data)
- }
+ require.NoError(t, err, "should not get error from parsing option")
+ require.Equal(t, OptionNameServer, opt.Code, "opt should have the same opcode")
+ require.Equal(t, option[2:], opt.Data, "opt should have the same data")
}
func TestParseOptionPad(t *testing.T) {
option := []byte{0}
opt, err := ParseOption(option)
- if err != nil {
- t.Fatal(err)
- }
- if opt.Code != OptionPad {
- t.Fatalf("Invalid option code. Expected %v, got %v", OptionPad, opt.Code)
- }
- if len(opt.Data) != 0 {
- t.Fatalf("Invalid option data. Expected empty slice, got %v", opt.Data)
- }
+ require.NoError(t, err, "should not get error from parsing option")
+ require.Equal(t, OptionPad, opt.Code, "should get pad option code")
+ require.Empty(t, opt.Data, "should get empty data with pad option")
}
func TestParseOptionZeroLength(t *testing.T) {
option := []byte{}
_, err := ParseOption(option)
- if err == nil {
- t.Fatal("Expected an error, got none")
- }
+ require.Error(t, err, "should get error from zero-length options")
}
func TestParseOptionShortOption(t *testing.T) {
option := []byte{53, 1}
_, err := ParseOption(option)
- if err == nil {
- t.Fatal(err)
- }
+ require.Error(t, err, "should get error from short options")
}
func TestOptionsFromBytes(t *testing.T) {
@@ -57,41 +42,29 @@ func TestOptionsFromBytes(t *testing.T) {
0, 0, 0, //padding
}
opts, err := OptionsFromBytes(options)
- if err != nil {
- t.Fatal(err)
- }
- // each padding byte counts as an option. Magic Cookie doesn't add up
- if len(opts) != 5 {
- t.Fatal("Invalid options length. Expected 5, got %v", len(opts))
- }
- if opts[0].Code != OptionNameServer {
- t.Fatal("Invalid option code. Expected %v, got %v", OptionNameServer, opts[0].Code)
- }
- if !bytes.Equal(opts[0].Data, options[6:10]) {
- t.Fatal("Invalid option data. Expected %v, got %v", options[6:10], opts[0].Data)
- }
- if opts[1].Code != OptionEnd {
- t.Fatalf("Invalid option code. Expected %v, got %v", OptionEnd, opts[1].Code)
- }
- if opts[2].Code != OptionPad {
- t.Fatalf("Invalid option code. Expected %v, got %v", OptionPad, opts[2].Code)
- }
+ require.NoError(t, err)
+ require.Equal(t, []Option{
+ Option{
+ Code: OptionNameServer,
+ Data: []byte{192, 168, 1, 1},
+ },
+ Option{Code: OptionEnd, Data: []byte{}},
+ Option{Code: OptionPad, Data: []byte{}},
+ Option{Code: OptionPad, Data: []byte{}},
+ Option{Code: OptionPad, Data: []byte{}},
+ }, opts)
}
func TestOptionsFromBytesZeroLength(t *testing.T) {
options := []byte{}
_, err := OptionsFromBytes(options)
- if err == nil {
- t.Fatal("Expected an error, got none")
- }
+ require.Error(t, err)
}
func TestOptionsFromBytesBadMagicCookie(t *testing.T) {
options := []byte{1, 2, 3, 4}
_, err := OptionsFromBytes(options)
- if err == nil {
- t.Fatal("Expected an error, got none")
- }
+ require.Error(t, err)
}
func TestOptionsToBytes(t *testing.T) {
@@ -102,47 +75,31 @@ func TestOptionsToBytes(t *testing.T) {
0, 0, 0, //padding
}
options, err := OptionsFromBytes(originalOptions)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
finalOptions := OptionsToBytes(options)
- if !bytes.Equal(originalOptions, finalOptions) {
- t.Fatalf("Invalid options. Expected %v, got %v", originalOptions, finalOptions)
- }
+ require.Equal(t, originalOptions, finalOptions)
}
func TestOptionsToBytesEmpty(t *testing.T) {
originalOptions := []byte{99, 130, 83, 99}
options, err := OptionsFromBytes(originalOptions)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
finalOptions := OptionsToBytes(options)
- if !bytes.Equal(originalOptions, finalOptions) {
- t.Fatalf("Invalid options. Expected %v, got %v", originalOptions, finalOptions)
- }
+ require.Equal(t, originalOptions, finalOptions)
}
func TestOptionsToStringPad(t *testing.T) {
option := []byte{0}
opt, err := ParseOption(option)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
stropt := opt.String()
- if stropt != "Pad -> []" {
- t.Fatalf("Invalid string representation: %v", stropt)
- }
+ require.Equal(t, "Pad -> []", stropt)
}
func TestOptionsToStringDHCPMessageType(t *testing.T) {
option := []byte{53, 1, 5}
opt, err := ParseOption(option)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
stropt := opt.String()
- if stropt != "DHCP Message Type -> [5]" {
- t.Fatalf("Invalid string representation: %v", stropt)
- }
+ require.Equal(t, "DHCP Message Type -> [5]", stropt)
}