package dhcpv4 import ( "net" "testing" "github.com/insomniacslk/dhcp/iana" "github.com/stretchr/testify/require" ) func RequireEqualIPAddr(t *testing.T, a, b net.IP, msg ...interface{}) { if !net.IP.Equal(a, b) { t.Fatalf("Invalid %s. %v != %v", msg, a, b) } } func TestFromBytes(t *testing.T) { data := []byte{ 1, // dhcp request 1, // ethernet hw type 6, // hw addr length 3, // hop count 0xaa, 0xbb, 0xcc, 0xdd, // transaction ID, big endian (network) 0, 3, // number of seconds 0, 1, // broadcast 0, 0, 0, 0, // client IP address 0, 0, 0, 0, // your IP address 0, 0, 0, 0, // server IP address 0, 0, 0, 0, // gateway IP address 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // client MAC address + padding } // server host name expectedHostname := []byte{} for i := 0; i < 64; i++ { expectedHostname = append(expectedHostname, 0) } data = append(data, expectedHostname...) // boot file name expectedBootfilename := []byte{} for i := 0; i < 128; i++ { expectedBootfilename = append(expectedBootfilename, 0) } data = append(data, expectedBootfilename...) // magic cookie, then no options data = append(data, []byte{99, 130, 83, 99}...) d, err := FromBytes(data) require.NoError(t, err) require.Equal(t, d.Opcode(), OpcodeBootRequest) require.Equal(t, d.HwType(), iana.HwTypeEthernet) require.Equal(t, d.HwAddrLen(), byte(6)) require.Equal(t, d.HopCount(), byte(3)) require.Equal(t, d.TransactionID(), uint32(0xaabbccdd)) require.Equal(t, d.NumSeconds(), uint16(3)) require.Equal(t, d.Flags(), uint16(1)) RequireEqualIPAddr(t, d.ClientIPAddr(), net.IPv4zero) RequireEqualIPAddr(t, d.YourIPAddr(), net.IPv4zero) RequireEqualIPAddr(t, d.GatewayIPAddr(), net.IPv4zero) clientHwAddr := d.ClientHwAddr() require.Equal(t, clientHwAddr[:], []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) hostname := d.ServerHostName() require.Equal(t, hostname[:], expectedHostname) bootfileName := d.BootFileName() require.Equal(t, bootfileName[:], expectedBootfilename) // no need to check Magic Cookie as it is already validated in FromBytes // above } func TestFromBytesZeroLength(t *testing.T) { data := []byte{} _, err := FromBytes(data) require.Error(t, err) } func TestFromBytesShortLength(t *testing.T) { data := []byte{1, 1, 6, 0} _, err := FromBytes(data) require.Error(t, err) } func TestFromBytesInvalidOptions(t *testing.T) { data := []byte{ 1, // dhcp request 1, // ethernet hw type 6, // hw addr length 0, // hop count 0xaa, 0xbb, 0xcc, 0xdd, // transaction ID 3, 0, // number of seconds 1, 0, // broadcast 0, 0, 0, 0, // client IP address 0, 0, 0, 0, // your IP address 0, 0, 0, 0, // server IP address 0, 0, 0, 0, // gateway IP address 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // client MAC address + padding } // server host name for i := 0; i < 64; i++ { data = append(data, 0) } // boot file name for i := 0; i < 128; i++ { data = append(data, 0) } // invalid magic cookie, forcing option parsing to fail data = append(data, []byte{99, 130, 83, 98}...) _, err := FromBytes(data) require.Error(t, err) } func TestSettersAndGetters(t *testing.T) { data := []byte{ 1, // dhcp request 1, // ethernet hw type 6, // hw addr length 3, // hop count 0xaa, 0xbb, 0xcc, 0xdd, // transaction ID, big endian (network) 0, 3, // number of seconds 0, 1, // broadcast 1, 2, 3, 4, // client IP address 5, 6, 7, 8, // your IP address 9, 10, 11, 12, // server IP address 13, 14, 15, 16, // gateway IP address 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // client MAC address + padding } // server host name expectedHostname := []byte{} for i := 0; i < 64; i++ { expectedHostname = append(expectedHostname, 0) } data = append(data, expectedHostname...) // boot file name expectedBootfilename := []byte{} for i := 0; i < 128; i++ { expectedBootfilename = append(expectedBootfilename, 0) } data = append(data, expectedBootfilename...) // magic cookie, then no options data = append(data, []byte{99, 130, 83, 99}...) d, err := FromBytes(data) require.NoError(t, err) // getter/setter for Opcode require.Equal(t, OpcodeBootRequest, d.Opcode()) d.SetOpcode(OpcodeBootReply) require.Equal(t, OpcodeBootReply, d.Opcode()) // getter/setter for HwType require.Equal(t, iana.HwTypeEthernet, d.HwType()) d.SetHwType(iana.HwTypeARCNET) require.Equal(t, iana.HwTypeARCNET, d.HwType()) // getter/setter for HwAddrLen require.Equal(t, uint8(6), d.HwAddrLen()) d.SetHwAddrLen(12) require.Equal(t, uint8(12), d.HwAddrLen()) // getter/setter for HopCount require.Equal(t, uint8(3), d.HopCount()) d.SetHopCount(1) require.Equal(t, uint8(1), d.HopCount()) // getter/setter for TransactionID require.Equal(t, uint32(0xaabbccdd), d.TransactionID()) d.SetTransactionID(0xeeff0011) require.Equal(t, uint32(0xeeff0011), d.TransactionID()) // getter/setter for TransactionID require.Equal(t, uint16(3), d.NumSeconds()) d.SetNumSeconds(15) require.Equal(t, uint16(15), d.NumSeconds()) // getter/setter for Flags require.Equal(t, uint16(1), d.Flags()) d.SetFlags(0) require.Equal(t, uint16(0), d.Flags()) // getter/setter for ClientIPAddr RequireEqualIPAddr(t, net.IPv4(1, 2, 3, 4), d.ClientIPAddr()) d.SetClientIPAddr(net.IPv4(4, 3, 2, 1)) RequireEqualIPAddr(t, net.IPv4(4, 3, 2, 1), d.ClientIPAddr()) // getter/setter for YourIPAddr RequireEqualIPAddr(t, net.IPv4(5, 6, 7, 8), d.YourIPAddr()) d.SetYourIPAddr(net.IPv4(8, 7, 6, 5)) RequireEqualIPAddr(t, net.IPv4(8, 7, 6, 5), d.YourIPAddr()) // getter/setter for ServerIPAddr RequireEqualIPAddr(t, net.IPv4(9, 10, 11, 12), d.ServerIPAddr()) d.SetServerIPAddr(net.IPv4(12, 11, 10, 9)) RequireEqualIPAddr(t, net.IPv4(12, 11, 10, 9), d.ServerIPAddr()) // getter/setter for GatewayIPAddr RequireEqualIPAddr(t, net.IPv4(13, 14, 15, 16), d.GatewayIPAddr()) d.SetGatewayIPAddr(net.IPv4(16, 15, 14, 13)) RequireEqualIPAddr(t, net.IPv4(16, 15, 14, 13), d.GatewayIPAddr()) // getter/setter for ClientHwAddr hwaddr := d.ClientHwAddr() require.Equal(t, []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, hwaddr[:]) d.SetFlags(0) // getter/setter for ServerHostName serverhostname := d.ServerHostName() require.Equal(t, expectedHostname, serverhostname[:]) newHostname := []byte{'t', 'e', 's', 't'} for i := 0; i < 60; i++ { newHostname = append(newHostname, 0) } d.SetServerHostName(newHostname) serverhostname = d.ServerHostName() require.Equal(t, newHostname, serverhostname[:]) // getter/setter for BootFileName bootfilename := d.BootFileName() require.Equal(t, expectedBootfilename, bootfilename[:]) newBootfilename := []byte{'t', 'e', 's', 't'} for i := 0; i < 124; i++ { newBootfilename = append(newBootfilename, 0) } d.SetBootFileName(newBootfilename) bootfilename = d.BootFileName() require.Equal(t, newBootfilename, bootfilename[:]) } func TestToStringMethods(t *testing.T) { d, err := New() if err != nil { t.Fatal(err) } // OpcodeToString d.SetOpcode(OpcodeBootRequest) require.Equal(t, "BootRequest", d.OpcodeToString()) d.SetOpcode(OpcodeBootReply) require.Equal(t, "BootReply", d.OpcodeToString()) d.SetOpcode(OpcodeType(0)) require.Equal(t, "Invalid", d.OpcodeToString()) // HwTypeToString d.SetHwType(iana.HwTypeEthernet) require.Equal(t, "Ethernet", d.HwTypeToString()) d.SetHwType(iana.HwTypeARCNET) require.Equal(t, "ARCNET", d.HwTypeToString()) // FlagsToString d.SetUnicast() require.Equal(t, "Unicast", d.FlagsToString()) d.SetBroadcast() require.Equal(t, "Broadcast", d.FlagsToString()) d.SetFlags(0xffff) require.Equal(t, "Broadcast (reserved bits not zeroed)", d.FlagsToString()) // ClientHwAddrToString d.SetHwAddrLen(6) d.SetClientHwAddr([]byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) require.Equal(t, "aa:bb:cc:dd:ee:ff", d.ClientHwAddrToString()) // ServerHostNameToString d.SetServerHostName([]byte("my.host.local")) require.Equal(t, "my.host.local", d.ServerHostNameToString()) // BootFileNameToString d.SetBootFileName([]byte("/my/boot/file")) require.Equal(t, "/my/boot/file", d.BootFileNameToString()) } func TestNewToBytes(t *testing.T) { // the following bytes match what dhcpv4.New would create. Keep them in // sync! expected := []byte{ 1, // Opcode BootRequest 1, // HwType Ethernet 6, // HwAddrLen 0, // HopCount 0x11, 0x22, 0x33, 0x44, // TransactionID 0, 0, // NumSeconds 0, 0, // Flags 0, 0, 0, 0, // ClientIPAddr 0, 0, 0, 0, // YourIPAddr 0, 0, 0, 0, // ServerIPAddr 0, 0, 0, 0, // GatewayIPAddr 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // ClientHwAddr } // ServerHostName for i := 0; i < 64; i++ { expected = append(expected, 0) } // BootFileName for i := 0; i < 128; i++ { expected = append(expected, 0) } // Magic Cookie expected = append(expected, MagicCookie...) // End expected = append(expected, 0xff) d, err := New() require.NoError(t, err) // fix TransactionID to match the expected one, since it's randomly // generated in New() d.SetTransactionID(0x11223344) got := d.ToBytes() require.Equal(t, expected, got) } func TestGetOption(t *testing.T) { d, err := New() if err != nil { t.Fatal(err) } hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")} bootFileOpt1 := &OptionGeneric{OptionCode: OptionBootfileName, Data: []byte("boot.img")} bootFileOpt2 := &OptionGeneric{OptionCode: OptionBootfileName, Data: []byte("boot2.img")} d.AddOption(hostnameOpt) d.AddOption(bootFileOpt1) d.AddOption(bootFileOpt2) require.Equal(t, d.GetOption(OptionHostName), []Option{hostnameOpt}) require.Equal(t, d.GetOption(OptionBootfileName), []Option{bootFileOpt1, bootFileOpt2}) require.Equal(t, d.GetOption(OptionRouter), []Option{}) require.Equal(t, d.GetOneOption(OptionHostName), hostnameOpt) require.Equal(t, d.GetOneOption(OptionBootfileName), bootFileOpt1) require.Equal(t, d.GetOneOption(OptionRouter), nil) } func TestAddOption(t *testing.T) { d, err := New() if err != nil { t.Fatal(err) } hostnameOpt := &OptionGeneric{OptionCode: OptionHostName, Data: []byte("darkstar")} bootFileOpt1 := &OptionGeneric{OptionCode: OptionBootfileName, Data: []byte("boot.img")} bootFileOpt2 := &OptionGeneric{OptionCode: OptionBootfileName, Data: []byte("boot2.img")} d.AddOption(hostnameOpt) d.AddOption(bootFileOpt1) d.AddOption(bootFileOpt2) options := d.Options() require.Equal(t, len(options), 4) require.Equal(t, options[3].Code(), OptionEnd) } func TestDHCPv4RequestFromOffer(t *testing.T) { offer, err := New() require.NoError(t, err) offer.AddOption(&OptMessageType{MessageType: MessageTypeOffer}) offer.AddOption(&OptServerIdentifier{ServerID: net.IPv4(192, 168, 0, 1)}) req, err := RequestFromOffer(*offer) require.NoError(t, err) require.NotEqual(t, (*MessageType)(nil), *req.MessageType()) require.Equal(t, MessageTypeRequest, *req.MessageType()) } func TestDHCPv4RequestFromOfferWithModifier(t *testing.T) { offer, err := New() require.NoError(t, err) offer.AddOption(&OptMessageType{MessageType: MessageTypeOffer}) offer.AddOption(&OptServerIdentifier{ServerID: net.IPv4(192, 168, 0, 1)}) userClass := WithUserClass([]byte("linuxboot"), false) req, err := RequestFromOffer(*offer, userClass) require.NoError(t, err) require.NotEqual(t, (*MessageType)(nil), *req.MessageType()) require.Equal(t, MessageTypeRequest, *req.MessageType()) require.Equal(t, "User Class Information -> linuxboot", req.options[3].String()) } func TestNewReplyFromRequest(t *testing.T) { discover, err := New() require.NoError(t, err) discover.SetGatewayIPAddr(net.IPv4(192, 168, 0, 1)) reply, err := NewReplyFromRequest(discover) require.NoError(t, err) require.Equal(t, discover.TransactionID(), reply.TransactionID()) require.Equal(t, discover.GatewayIPAddr(), reply.GatewayIPAddr()) } func TestNewReplyFromRequestWithModifier(t *testing.T) { discover, err := New() require.NoError(t, err) discover.SetGatewayIPAddr(net.IPv4(192, 168, 0, 1)) userClass := WithUserClass([]byte("linuxboot"), false) reply, err := NewReplyFromRequest(discover, userClass) require.NoError(t, err) require.Equal(t, discover.TransactionID(), reply.TransactionID()) require.Equal(t, discover.GatewayIPAddr(), reply.GatewayIPAddr()) require.Equal(t, "User Class Information -> linuxboot", reply.options[0].String()) } func TestDHCPv4MessageTypeNil(t *testing.T) { m, err := New() require.NoError(t, err) require.Equal(t, (*MessageType)(nil), m.MessageType()) } func TestDHCPv4MessageTypeDiscovery(t *testing.T) { m, err := NewDiscoveryForInterface("lo") require.NoError(t, err) require.NotEqual(t, (*MessageType)(nil), m.MessageType()) require.Equal(t, MessageTypeDiscover, *m.MessageType()) } // TODO // test broadcast/unicast flags // test Options setter/getter // test Summary() and String()