diff options
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/async/client.go | 8 | ||||
-rw-r--r-- | dhcpv4/async/client_test.go | 2 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp.go | 30 | ||||
-rw-r--r-- | dhcpv4/bsdp/bsdp_test.go | 49 | ||||
-rw-r--r-- | dhcpv4/bsdp/client.go | 2 | ||||
-rw-r--r-- | dhcpv4/client.go | 6 | ||||
-rw-r--r-- | dhcpv4/dhcpv4.go | 416 | ||||
-rw-r--r-- | dhcpv4/dhcpv4_test.go | 218 | ||||
-rw-r--r-- | dhcpv4/modifiers.go | 10 | ||||
-rw-r--r-- | dhcpv4/modifiers_test.go | 50 | ||||
-rw-r--r-- | dhcpv4/server_test.go | 10 | ||||
-rw-r--r-- | dhcpv4/types.go | 15 |
12 files changed, 258 insertions, 558 deletions
diff --git a/dhcpv4/async/client.go b/dhcpv4/async/client.go index c37d249..9844180 100644 --- a/dhcpv4/async/client.go +++ b/dhcpv4/async/client.go @@ -132,7 +132,7 @@ func (c *Client) senderLoop(ctx context.Context) { func (c *Client) send(packet *dhcpv4.DHCPv4) { c.packetsLock.Lock() - p := c.packets[packet.TransactionID()] + p := c.packets[packet.TransactionID] c.packetsLock.Unlock() raddr, err := c.remoteAddr() @@ -174,8 +174,8 @@ func (c *Client) receive(_ *dhcpv4.DHCPv4) { } c.packetsLock.Lock() - if p, ok := c.packets[received.TransactionID()]; ok { - delete(c.packets, received.TransactionID()) + if p, ok := c.packets[received.TransactionID]; ok { + delete(c.packets, received.TransactionID) p.Resolve(received) } c.packetsLock.Unlock() @@ -201,7 +201,7 @@ func (c *Client) Send(message *dhcpv4.DHCPv4, modifiers ...dhcpv4.Modifier) *pro p := promise.NewPromise() c.packetsLock.Lock() - c.packets[message.TransactionID()] = p + c.packets[message.TransactionID] = p c.packetsLock.Unlock() c.sendQueue <- message return p.Future diff --git a/dhcpv4/async/client_test.go b/dhcpv4/async/client_test.go index 4be6edd..7fa0e9e 100644 --- a/dhcpv4/async/client_test.go +++ b/dhcpv4/async/client_test.go @@ -121,5 +121,5 @@ func TestSend(t *testing.T) { require.True(t, ok) require.False(t, timeout) require.NoError(t, err) - require.Equal(t, m.TransactionID(), r.TransactionID()) + require.Equal(t, m.TransactionID, r.TransactionID) } diff --git a/dhcpv4/bsdp/bsdp.go b/dhcpv4/bsdp/bsdp.go index 8d4430a..3f97602 100644 --- a/dhcpv4/bsdp/bsdp.go +++ b/dhcpv4/bsdp/bsdp.go @@ -143,12 +143,10 @@ func InformSelectForAck(ack dhcpv4.DHCPv4, replyPort uint16, selectedImage BootI if needsReplyPort(replyPort) && replyPort >= 1024 { return nil, errors.New("replyPort must be a privileged port") } - d.SetOpcode(dhcpv4.OpcodeBootRequest) - d.SetHwType(ack.HwType()) - d.SetHwAddrLen(ack.HwAddrLen()) - clientHwAddr := ack.ClientHwAddr() - d.SetClientHwAddr(clientHwAddr[:]) - d.SetTransactionID(ack.TransactionID()) + d.OpCode = dhcpv4.OpcodeBootRequest + d.HWType = ack.HWType + d.ClientHWAddr = ack.ClientHWAddr + d.TransactionID = ack.TransactionID if ack.IsBroadcast() { d.SetBroadcast() } else { @@ -209,11 +207,10 @@ func NewReplyForInformList(inform *dhcpv4.DHCPv4, config ReplyConfig) (*dhcpv4.D if err != nil { return nil, err } - reply.SetClientIPAddr(inform.ClientIPAddr()) - reply.SetYourIPAddr(net.IPv4zero) - reply.SetGatewayIPAddr(inform.GatewayIPAddr()) - reply.SetServerIPAddr(config.ServerIP) - reply.SetServerHostName([]byte(config.ServerHostname)) + reply.ClientIPAddr = inform.ClientIPAddr + reply.GatewayIPAddr = inform.GatewayIPAddr + reply.ServerIPAddr = config.ServerIP + reply.ServerHostName = config.ServerHostname reply.AddOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) reply.AddOption(&dhcpv4.OptServerIdentifier{ServerID: config.ServerIP}) @@ -249,12 +246,11 @@ func NewReplyForInformSelect(inform *dhcpv4.DHCPv4, config ReplyConfig) (*dhcpv4 return nil, err } - reply.SetClientIPAddr(inform.ClientIPAddr()) - reply.SetYourIPAddr(net.IPv4zero) - reply.SetGatewayIPAddr(inform.GatewayIPAddr()) - reply.SetServerIPAddr(config.ServerIP) - reply.SetServerHostName([]byte(config.ServerHostname)) - reply.SetBootFileName([]byte(config.BootFileName)) + reply.ClientIPAddr = inform.ClientIPAddr + reply.GatewayIPAddr = inform.GatewayIPAddr + reply.ServerIPAddr = config.ServerIP + reply.ServerHostName = config.ServerHostname + reply.BootFileName = config.BootFileName reply.AddOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) reply.AddOption(&dhcpv4.OptServerIdentifier{ServerID: config.ServerIP}) diff --git a/dhcpv4/bsdp/bsdp_test.go b/dhcpv4/bsdp/bsdp_test.go index 610b8d6..cb703da 100644 --- a/dhcpv4/bsdp/bsdp_test.go +++ b/dhcpv4/bsdp/bsdp_test.go @@ -103,20 +103,19 @@ func TestNewInformList_ReplyPort(t *testing.T) { require.Equal(t, replyPort, opt.(*OptReplyPort).Port) } -func newAck(hwAddr []byte, transactionID uint32) *dhcpv4.DHCPv4 { +func newAck(hwAddr net.HardwareAddr, transactionID [4]byte) *dhcpv4.DHCPv4 { ack, _ := dhcpv4.New() - ack.SetTransactionID(transactionID) - ack.SetHwType(iana.HwTypeEthernet) - ack.SetClientHwAddr(hwAddr) - ack.SetHwAddrLen(uint8(len(hwAddr))) + ack.TransactionID = transactionID + ack.HWType = iana.HwTypeEthernet + ack.ClientHWAddr = hwAddr ack.AddOption(&dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) ack.AddOption(&dhcpv4.OptionGeneric{OptionCode: dhcpv4.OptionEnd}) return ack } func TestInformSelectForAck_Broadcast(t *testing.T) { - hwAddr := []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} - tid := uint32(22) + hwAddr := net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} + tid := [4]byte{0x22, 0, 0, 0} serverID := net.IPv4(1, 2, 3, 4) bootImage := BootImage{ ID: BootImageID{ @@ -132,10 +131,10 @@ func TestInformSelectForAck_Broadcast(t *testing.T) { m, err := InformSelectForAck(*ack, 0, bootImage) require.NoError(t, err) - require.Equal(t, dhcpv4.OpcodeBootRequest, m.Opcode()) - require.Equal(t, ack.HwType(), m.HwType()) - require.Equal(t, ack.ClientHwAddr(), m.ClientHwAddr()) - require.Equal(t, ack.TransactionID(), m.TransactionID()) + require.Equal(t, dhcpv4.OpcodeBootRequest, m.OpCode) + require.Equal(t, ack.HWType, m.HWType) + require.Equal(t, ack.ClientHWAddr, m.ClientHWAddr) + require.Equal(t, ack.TransactionID, m.TransactionID) require.True(t, m.IsBroadcast()) // Validate options. @@ -157,8 +156,8 @@ func TestInformSelectForAck_Broadcast(t *testing.T) { } func TestInformSelectForAck_NoServerID(t *testing.T) { - hwAddr := []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} - tid := uint32(22) + hwAddr := net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} + tid := [4]byte{0x22, 0, 0, 0} bootImage := BootImage{ ID: BootImageID{ IsInstall: true, @@ -174,8 +173,8 @@ func TestInformSelectForAck_NoServerID(t *testing.T) { } func TestInformSelectForAck_BadReplyPort(t *testing.T) { - hwAddr := []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} - tid := uint32(22) + hwAddr := net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} + tid := [4]byte{0x22, 0, 0, 0} serverID := net.IPv4(1, 2, 3, 4) bootImage := BootImage{ ID: BootImageID{ @@ -194,8 +193,8 @@ func TestInformSelectForAck_BadReplyPort(t *testing.T) { } func TestInformSelectForAck_ReplyPort(t *testing.T) { - hwAddr := []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} - tid := uint32(22) + hwAddr := net.HardwareAddr{0x11, 0x22, 0x33, 0x44, 0x55, 0x66} + tid := [4]byte{0x22, 0, 0, 0} serverID := net.IPv4(1, 2, 3, 4) bootImage := BootImage{ ID: BootImageID{ @@ -272,9 +271,9 @@ func TestNewReplyForInformList(t *testing.T) { } ack, err := NewReplyForInformList(inform, config) require.NoError(t, err) - require.Equal(t, net.IP{1, 2, 3, 4}, ack.ClientIPAddr()) - require.Equal(t, net.IPv4zero, ack.YourIPAddr()) - require.Equal(t, "bsdp.foo.com", ack.ServerHostName()) + require.Equal(t, net.IP{1, 2, 3, 4}, ack.ClientIPAddr) + require.Equal(t, net.IPv4zero, ack.YourIPAddr) + require.Equal(t, "bsdp.foo.com", ack.ServerHostName) // Validate options. RequireHasOption(t, ack, &dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) @@ -283,7 +282,7 @@ func TestNewReplyForInformList(t *testing.T) { require.NotNil(t, ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation)) // Ensure options terminated with End option. - require.Equal(t, &dhcpv4.OptionGeneric{OptionCode: dhcpv4.OptionEnd}, ack.Options()[len(ack.Options())-1]) + require.Equal(t, &dhcpv4.OptionGeneric{OptionCode: dhcpv4.OptionEnd}, ack.Options[len(ack.Options)-1]) // Vendor-specific options. vendorOpts := ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation).(*OptVendorSpecificInformation) @@ -353,9 +352,9 @@ func TestNewReplyForInformSelect(t *testing.T) { } ack, err := NewReplyForInformSelect(inform, config) require.NoError(t, err) - require.Equal(t, net.IP{1, 2, 3, 4}, ack.ClientIPAddr()) - require.Equal(t, net.IPv4zero, ack.YourIPAddr()) - require.Equal(t, "bsdp.foo.com", ack.ServerHostName()) + require.Equal(t, net.IP{1, 2, 3, 4}, ack.ClientIPAddr) + require.Equal(t, net.IPv4zero, ack.YourIPAddr) + require.Equal(t, "bsdp.foo.com", ack.ServerHostName) // Validate options. RequireHasOption(t, ack, &dhcpv4.OptMessageType{MessageType: dhcpv4.MessageTypeAck}) @@ -365,7 +364,7 @@ func TestNewReplyForInformSelect(t *testing.T) { require.NotNil(t, ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation)) // Ensure options are terminated with End option. - require.Equal(t, &dhcpv4.OptionGeneric{OptionCode: dhcpv4.OptionEnd}, ack.Options()[len(ack.Options())-1]) + require.Equal(t, &dhcpv4.OptionGeneric{OptionCode: dhcpv4.OptionEnd}, ack.Options[len(ack.Options)-1]) vendorOpts := ack.GetOneOption(dhcpv4.OptionVendorSpecificInformation).(*OptVendorSpecificInformation) RequireHasOption(t, vendorOpts, &OptMessageType{Type: MessageTypeSelect}) diff --git a/dhcpv4/bsdp/client.go b/dhcpv4/bsdp/client.go index dd9cbcd..dd4a0a0 100644 --- a/dhcpv4/bsdp/client.go +++ b/dhcpv4/bsdp/client.go @@ -19,7 +19,7 @@ func NewClient() *Client { } func castVendorOpt(ack *dhcpv4.DHCPv4) { - opts := ack.Options() + opts := ack.Options for i := 0; i < len(opts); i++ { if opts[i].Code() == dhcpv4.OptionVendorSpecificInformation { vendorOpt, err := ParseOptVendorSpecificInformation(opts[i].ToBytes()) diff --git a/dhcpv4/client.go b/dhcpv4/client.go index 8ec1490..d2d18db 100644 --- a/dhcpv4/client.go +++ b/dhcpv4/client.go @@ -331,11 +331,11 @@ func (c *Client) SendReceive(sendFd, recvFd int, packet *DHCPv4, messageType Mes return } // check that this is a response to our message - if response.TransactionID() != packet.TransactionID() { + if response.TransactionID != packet.TransactionID { continue } // wait for a response message - if response.Opcode() != OpcodeBootReply { + if response.OpCode != OpcodeBootReply { continue } // if we are not requested to wait for a specific message type, @@ -344,7 +344,7 @@ func (c *Client) SendReceive(sendFd, recvFd int, packet *DHCPv4, messageType Mes break } // break if it's a reply of the desired type, continue otherwise - if response.MessageType() != nil && *response.MessageType() == messageType { + if response.MessageType() == messageType { break } } diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index 35e1917..b24502f 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -2,6 +2,7 @@ package dhcpv4 import ( "crypto/rand" + "encoding/binary" "errors" "fmt" "log" @@ -32,20 +33,20 @@ var magicCookie = [4]byte{99, 130, 83, 99} // DHCPv4 represents a DHCPv4 packet header and options. See the New* functions // to build DHCPv4 packets. type DHCPv4 struct { - opcode OpcodeType - hwType iana.HwTypeType - hopCount uint8 - transactionID TransactionID - numSeconds uint16 - flags uint16 - clientIPAddr net.IP - yourIPAddr net.IP - serverIPAddr net.IP - gatewayIPAddr net.IP - clientHwAddr net.HardwareAddr - serverHostName string - bootFileName string - options []Option + OpCode OpcodeType + HWType iana.HwTypeType + HopCount uint8 + TransactionID TransactionID + NumSeconds uint16 + Flags uint16 + ClientIPAddr net.IP + YourIPAddr net.IP + ServerIPAddr net.IP + GatewayIPAddr net.IP + ClientHWAddr net.HardwareAddr + ServerHostName string + BootFileName string + Options []Option } // Modifier defines the signature for functions that can modify DHCPv4 @@ -112,18 +113,17 @@ func New() (*DHCPv4, error) { return nil, err } d := DHCPv4{ - opcode: OpcodeBootRequest, - hwType: iana.HwTypeEthernet, - hopCount: 0, - transactionID: xid, - numSeconds: 0, - flags: 0, - clientHwAddr: net.HardwareAddr{0, 0, 0, 0, 0, 0}, - clientIPAddr: net.IPv4zero, - yourIPAddr: net.IPv4zero, - serverIPAddr: net.IPv4zero, - gatewayIPAddr: net.IPv4zero, - options: make([]Option, 0, 10), + OpCode: OpcodeBootRequest, + HWType: iana.HwTypeEthernet, + HopCount: 0, + TransactionID: xid, + NumSeconds: 0, + Flags: 0, + ClientIPAddr: net.IPv4zero, + YourIPAddr: net.IPv4zero, + ServerIPAddr: net.IPv4zero, + GatewayIPAddr: net.IPv4zero, + Options: make([]Option, 0, 10), } // the End option has to be added explicitly d.AddOption(&OptionGeneric{OptionCode: OptionEnd}) @@ -149,9 +149,9 @@ func NewDiscovery(hwaddr net.HardwareAddr) (*DHCPv4, error) { return nil, err } // get hw addr - d.SetOpcode(OpcodeBootRequest) - d.SetHwType(iana.HwTypeEthernet) - d.SetClientHwAddr(hwaddr) + d.OpCode = OpcodeBootRequest + d.HWType = iana.HwTypeEthernet + d.ClientHWAddr = hwaddr d.SetBroadcast() d.AddOption(&OptMessageType{MessageType: MessageTypeDiscover}) d.AddOption(&OptParameterRequestList{ @@ -202,10 +202,10 @@ func NewInform(hwaddr net.HardwareAddr, localIP net.IP) (*DHCPv4, error) { return nil, err } - d.SetOpcode(OpcodeBootRequest) - d.SetHwType(iana.HwTypeEthernet) - d.SetClientHwAddr(hwaddr) - d.SetClientIPAddr(localIP) + d.OpCode = OpcodeBootRequest + d.HWType = iana.HwTypeEthernet + d.ClientHWAddr = hwaddr + d.ClientIPAddr = localIP d.AddOption(&OptMessageType{MessageType: MessageTypeInform}) return d, nil } @@ -216,10 +216,10 @@ func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) if err != nil { return nil, err } - d.SetOpcode(OpcodeBootRequest) - d.SetHwType(offer.HwType()) - d.SetClientHwAddr(offer.ClientHwAddr()) - d.SetTransactionID(offer.TransactionID()) + d.OpCode = OpcodeBootRequest + d.HWType = offer.HWType + d.ClientHWAddr = offer.ClientHWAddr + d.TransactionID = offer.TransactionID if offer.IsBroadcast() { d.SetBroadcast() } else { @@ -227,7 +227,7 @@ func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) } // find server IP address var serverIP []byte - for _, opt := range offer.options { + for _, opt := range offer.Options { if opt.Code() == OptionServerIdentifier { serverIP = opt.(*OptServerIdentifier).ServerID } @@ -235,9 +235,9 @@ func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) if serverIP == nil { return nil, errors.New("Missing Server IP Address in DHCP Offer") } - d.SetServerIPAddr(serverIP) + d.ServerIPAddr = serverIP d.AddOption(&OptMessageType{MessageType: MessageTypeRequest}) - d.AddOption(&OptRequestedIPAddress{RequestedAddr: offer.YourIPAddr()}) + d.AddOption(&OptRequestedIPAddress{RequestedAddr: offer.YourIPAddr}) d.AddOption(&OptServerIdentifier{ServerID: serverIP}) for _, mod := range modifiers { d = mod(d) @@ -251,12 +251,12 @@ func NewReplyFromRequest(request *DHCPv4, modifiers ...Modifier) (*DHCPv4, error if err != nil { return nil, err } - reply.SetOpcode(OpcodeBootReply) - reply.SetHwType(request.HwType()) - reply.SetClientHwAddr(request.ClientHwAddr()) - reply.SetTransactionID(request.TransactionID()) - reply.SetFlags(request.Flags()) - reply.SetGatewayIPAddr(request.GatewayIPAddr()) + reply.OpCode = OpcodeBootReply + reply.HWType = request.HWType + reply.ClientHWAddr = request.ClientHWAddr + reply.TransactionID = request.TransactionID + reply.Flags = request.Flags + reply.GatewayIPAddr = request.GatewayIPAddr for _, mod := range modifiers { reply = mod(reply) } @@ -269,28 +269,28 @@ func FromBytes(q []byte) (*DHCPv4, error) { var p DHCPv4 buf := uio.NewBigEndianBuffer(q) - p.opcode = OpcodeType(buf.Read8()) - p.hwType = iana.HwTypeType(buf.Read8()) - hwAddrLen := buf.Read8() - p.hopCount = buf.Read8() + p.OpCode = OpcodeType(buf.Read8()) + p.HWType = iana.HwTypeType(buf.Read8()) - buf.ReadBytes(p.transactionID[:]) + hwAddrLen := buf.Read8() - p.numSeconds = buf.Read16() - p.flags = buf.Read16() + p.HopCount = buf.Read8() + buf.ReadBytes(p.TransactionID[:]) + p.NumSeconds = buf.Read16() + p.Flags = buf.Read16() - p.clientIPAddr = net.IP(buf.CopyN(net.IPv4len)) - p.yourIPAddr = net.IP(buf.CopyN(net.IPv4len)) - p.serverIPAddr = net.IP(buf.CopyN(net.IPv4len)) - p.gatewayIPAddr = net.IP(buf.CopyN(net.IPv4len)) + p.ClientIPAddr = net.IP(buf.CopyN(net.IPv4len)) + p.YourIPAddr = net.IP(buf.CopyN(net.IPv4len)) + p.ServerIPAddr = net.IP(buf.CopyN(net.IPv4len)) + p.GatewayIPAddr = net.IP(buf.CopyN(net.IPv4len)) - if hwAddrLen > maxHWAddrLen { - hwAddrLen = maxHWAddrLen + if hwAddrLen > 16 { + hwAddrLen = 16 } - // Always read 16 bytes, but only use hwAddrLen of them. - p.clientHwAddr = make(net.HardwareAddr, maxHWAddrLen) - buf.ReadBytes(p.clientHwAddr) - p.clientHwAddr = p.clientHwAddr[:hwAddrLen] + // Always read 16 bytes, but only use hwaddrlen of them. + p.ClientHWAddr = make(net.HardwareAddr, 16) + buf.ReadBytes(p.ClientHWAddr) + p.ClientHWAddr = p.ClientHWAddr[:hwAddrLen] var sname [64]byte buf.ReadBytes(sname[:]) @@ -298,7 +298,7 @@ func FromBytes(q []byte) (*DHCPv4, error) { if length == -1 { length = 64 } - p.serverHostName = string(sname[:length]) + p.ServerHostName = string(sname[:length]) var file [128]byte buf.ReadBytes(file[:]) @@ -306,7 +306,7 @@ func FromBytes(q []byte) (*DHCPv4, error) { if length == -1 { length = 128 } - p.bootFileName = string(file[:length]) + p.BootFileName = string(file[:length]) var cookie [4]byte buf.ReadBytes(cookie[:]) @@ -322,88 +322,10 @@ func FromBytes(q []byte) (*DHCPv4, error) { if err != nil { return nil, err } - p.options = opts + p.Options = opts return &p, nil } -// Opcode returns the OpcodeType for the packet, -func (d *DHCPv4) Opcode() OpcodeType { - return d.opcode -} - -// OpcodeToString returns the mnemonic name for the packet's opcode. -func (d *DHCPv4) OpcodeToString() string { - return d.opcode.String() -} - -// SetOpcode sets a new opcode for the packet. It prints a warning if the opcode -// is unknown, but does not generate an error. -func (d *DHCPv4) SetOpcode(opcode OpcodeType) { - if _, ok := OpcodeToString[opcode]; !ok { - log.Printf("Warning: unknown DHCPv4 opcode: %v", opcode) - } - d.opcode = opcode -} - -// HwType returns the hardware type as defined by IANA. -func (d *DHCPv4) HwType() iana.HwTypeType { - return d.hwType -} - -// HwTypeToString returns the mnemonic name for the hardware type, e.g. -// "Ethernet". If the type is unknown, it returns "Unknown". -func (d *DHCPv4) HwTypeToString() string { - return d.hwType.String() -} - -// SetHwType returns the hardware type as defined by IANA. -func (d *DHCPv4) SetHwType(hwType iana.HwTypeType) { - if _, ok := iana.HwTypeToString[hwType]; !ok { - log.Printf("Warning: Invalid DHCPv4 hwtype: %v", hwType) - } - d.hwType = hwType -} - -// HopCount returns the hop count field. -func (d *DHCPv4) HopCount() uint8 { - return d.hopCount -} - -// SetHopCount sets the hop count value. -func (d *DHCPv4) SetHopCount(hopCount uint8) { - d.hopCount = hopCount -} - -// TransactionID returns the transaction ID as 32 bit unsigned integer. -func (d *DHCPv4) TransactionID() TransactionID { - return d.transactionID -} - -// SetTransactionID sets the value for the transaction ID. -func (d *DHCPv4) SetTransactionID(xid TransactionID) { - d.transactionID = xid -} - -// NumSeconds returns the number of seconds. -func (d *DHCPv4) NumSeconds() uint16 { - return d.numSeconds -} - -// SetNumSeconds sets the seconds field. -func (d *DHCPv4) SetNumSeconds(numSeconds uint16) { - d.numSeconds = numSeconds -} - -// Flags returns the DHCP flags portion of the packet. -func (d *DHCPv4) Flags() uint16 { - return d.flags -} - -// SetFlags sets the flags field in the packet. -func (d *DHCPv4) SetFlags(flags uint16) { - d.flags = flags -} - // FlagsToString returns a human-readable representation of the flags field. func (d *DHCPv4) FlagsToString() string { flags := "" @@ -412,7 +334,7 @@ func (d *DHCPv4) FlagsToString() string { } else { flags += "Unicast" } - if d.flags&0xfe != 0 { + if d.Flags&0xfe != 0 { flags += " (reserved bits not zeroed)" } return flags @@ -420,108 +342,22 @@ func (d *DHCPv4) FlagsToString() string { // IsBroadcast indicates whether the packet is a broadcast packet. func (d *DHCPv4) IsBroadcast() bool { - return d.flags&0x8000 == 0x8000 + return d.Flags&0x8000 == 0x8000 } // SetBroadcast sets the packet to be a broadcast packet. func (d *DHCPv4) SetBroadcast() { - d.flags |= 0x8000 + d.Flags |= 0x8000 } // IsUnicast indicates whether the packet is a unicast packet. func (d *DHCPv4) IsUnicast() bool { - return d.flags&0x8000 == 0 + return d.Flags&0x8000 == 0 } // SetUnicast sets the packet to be a unicast packet. func (d *DHCPv4) SetUnicast() { - d.flags &= ^uint16(0x8000) -} - -// ClientIPAddr returns the client IP address. -func (d *DHCPv4) ClientIPAddr() net.IP { - return d.clientIPAddr -} - -// SetClientIPAddr sets the client IP address. -func (d *DHCPv4) SetClientIPAddr(clientIPAddr net.IP) { - d.clientIPAddr = clientIPAddr -} - -// YourIPAddr returns the "your IP address" field. -func (d *DHCPv4) YourIPAddr() net.IP { - return d.yourIPAddr -} - -// SetYourIPAddr sets the "your IP address" field. -func (d *DHCPv4) SetYourIPAddr(yourIPAddr net.IP) { - d.yourIPAddr = yourIPAddr -} - -// ServerIPAddr returns the server IP address. -func (d *DHCPv4) ServerIPAddr() net.IP { - return d.serverIPAddr -} - -// SetServerIPAddr sets the server IP address. -func (d *DHCPv4) SetServerIPAddr(serverIPAddr net.IP) { - d.serverIPAddr = serverIPAddr -} - -// GatewayIPAddr returns the gateway IP address. -func (d *DHCPv4) GatewayIPAddr() net.IP { - return d.gatewayIPAddr -} - -// SetGatewayIPAddr sets the gateway IP address. -func (d *DHCPv4) SetGatewayIPAddr(gatewayIPAddr net.IP) { - d.gatewayIPAddr = gatewayIPAddr -} - -// ClientHwAddr returns the client hardware (MAC) address. -func (d *DHCPv4) ClientHwAddr() net.HardwareAddr { - return d.clientHwAddr -} - -// ClientHwAddrToString converts the hardware address field to a string. -func (d *DHCPv4) ClientHwAddrToString() string { - return d.clientHwAddr.String() -} - -// SetClientHwAddr sets the client hardware address. -func (d *DHCPv4) SetClientHwAddr(clientHwAddr net.HardwareAddr) { - if len(clientHwAddr) > maxHWAddrLen { - log.Printf("Warning: too long HW Address (%d bytes), truncating to 16 bytes", len(clientHwAddr)) - clientHwAddr = clientHwAddr[:maxHWAddrLen] - } - d.clientHwAddr = clientHwAddr -} - -// ServerHostName returns the server host name as a sequence of bytes. -func (d *DHCPv4) ServerHostName() string { - return d.serverHostName -} - -// SetServerHostName replaces the server host name, from a sequence of bytes, -// truncating it to the maximum length of 64. -func (d *DHCPv4) SetServerHostName(serverHostName string) { - d.serverHostName = serverHostName -} - -// BootFileName returns the boot file name as a sequence of bytes. -func (d *DHCPv4) BootFileName() string { - return d.bootFileName -} - -// SetBootFileName replaces the boot file name, from a sequence of bytes, -// truncating it to the maximum length oh 128. -func (d *DHCPv4) SetBootFileName(bootFileName string) { - d.bootFileName = bootFileName -} - -// Options returns the DHCPv4 options defined for the packet. -func (d *DHCPv4) Options() []Option { - return d.options + d.Flags &= ^uint16(0x8000) } // GetOption will attempt to get all options that match a DHCPv4 option @@ -529,7 +365,7 @@ func (d *DHCPv4) Options() []Option { // empty list. func (d *DHCPv4) GetOption(code OptionCode) []Option { opts := []Option{} - for _, opt := range d.Options() { + for _, opt := range d.Options { if opt.Code() == code { opts = append(opts, opt) } @@ -541,7 +377,7 @@ func (d *DHCPv4) GetOption(code OptionCode) []Option { // If there are multiple options with the same OptionCode it will only return // the first one found. If no matching option is found nil will be returned. func (d *DHCPv4) GetOneOption(code OptionCode) Option { - for _, opt := range d.Options() { + for _, opt := range d.Options { if opt.Code() == code { return opt } @@ -555,7 +391,7 @@ func (d *DHCPv4) StrippedOptions() []Option { // differently from Options() this function strips away anything coming // after the End option (normally just Pad options). strippedOptions := []Option{} - for _, opt := range d.options { + for _, opt := range d.Options { strippedOptions = append(strippedOptions, opt) if opt.Code() == OptionEnd { break @@ -564,30 +400,25 @@ func (d *DHCPv4) StrippedOptions() []Option { return strippedOptions } -// SetOptions replaces the current options with the provided ones. -func (d *DHCPv4) SetOptions(options []Option) { - d.options = options -} - // AddOption appends an option to the existing ones. If the last option is an // OptionEnd, it will be inserted before that. It does not deal with End // options that appead before the end, like in malformed packets. func (d *DHCPv4) AddOption(option Option) { - if len(d.options) == 0 || d.options[len(d.options)-1].Code() != OptionEnd { - d.options = append(d.options, option) + if len(d.Options) == 0 || d.Options[len(d.Options)-1].Code() != OptionEnd { + d.Options = append(d.Options, option) } else { - end := d.options[len(d.options)-1] - d.options[len(d.options)-1] = option - d.options = append(d.options, end) + end := d.Options[len(d.Options)-1] + d.Options[len(d.Options)-1] = option + d.Options = append(d.Options, end) } } // UpdateOption updates the existing options with the passed option, adding it // at the end if not present already func (d *DHCPv4) UpdateOption(option Option) { - for idx, opt := range d.options { + for idx, opt := range d.Options { if opt.Code() == option.Code() { - d.options[idx] = option + d.Options[idx] = option // don't look further return } @@ -598,17 +429,23 @@ func (d *DHCPv4) UpdateOption(option Option) { // MessageType returns the message type, trying to extract it from the // OptMessageType option. It returns nil if the message type cannot be extracted -func (d *DHCPv4) MessageType() *MessageType { +func (d *DHCPv4) MessageType() MessageType { opt := d.GetOneOption(OptionDHCPMessageType) if opt == nil { - return nil + return MessageTypeNone } - return &(opt.(*OptMessageType).MessageType) + return opt.(*OptMessageType).MessageType } +// HumanXID returns a human-readably integer transaction ID. +func (d *DHCPv4) HumanXID() uint32 { + return binary.LittleEndian.Uint32(d.TransactionID[:]) +} + +// String implements fmt.Stringer. func (d *DHCPv4) String() string { - return fmt.Sprintf("DHCPv4(opcode=%v xid=%d hwtype=%v hwaddr=%v)", - d.OpcodeToString(), d.TransactionID(), d.HwTypeToString(), d.ClientHwAddr()) + return fmt.Sprintf("DHCPv4(opcode=%s xid=%v hwtype=%s hwaddr=%s)", + d.OpCode.String(), d.HumanXID(), d.HWType, d.ClientHWAddr) } // Summary prints detailed information about the packet. @@ -628,23 +465,23 @@ func (d *DHCPv4) Summary() string { " clienthwaddr=%v\n"+ " serverhostname=%v\n"+ " bootfilename=%v\n", - d.OpcodeToString(), - d.HwTypeToString(), - d.HopCount(), - d.TransactionID(), - d.NumSeconds(), + d.OpCode, + d.HWType, + d.HopCount, + d.HumanXID(), + d.NumSeconds, d.FlagsToString(), - d.Flags(), - d.ClientIPAddr(), - d.YourIPAddr(), - d.ServerIPAddr(), - d.GatewayIPAddr(), - d.ClientHwAddrToString(), - d.ServerHostName(), - d.BootFileName(), + d.Flags, + d.ClientIPAddr, + d.YourIPAddr, + d.ServerIPAddr, + d.GatewayIPAddr, + d.ClientHWAddr, + d.ServerHostName, + d.BootFileName, ) ret += " options=\n" - for _, opt := range d.options { + for _, opt := range d.Options { optString := opt.String() // If this option has sub structures, offset them accordingly. if strings.Contains(optString, "\n") { @@ -663,7 +500,7 @@ func (d *DHCPv4) Summary() string { func (d *DHCPv4) ValidateOptions() { // TODO find duplicate options foundOptionEnd := false - for _, opt := range d.options { + for _, opt := range d.Options { if foundOptionEnd { if opt.Code() == OptionEnd { log.Print("Warning: found duplicate End option") @@ -704,44 +541,43 @@ func writeIP(b *uio.Lexer, ip net.IP) { } } -// ToBytes encodes a DHCPv4 structure into a sequence of bytes in its wire -// format. +// ToBytes writes the packet to binary. func (d *DHCPv4) ToBytes() []byte { buf := uio.NewBigEndianBuffer(make([]byte, 0, minPacketLen)) - buf.Write8(uint8(d.opcode)) - buf.Write8(uint8(d.hwType)) + buf.Write8(uint8(d.OpCode)) + buf.Write8(uint8(d.HWType)) // HwAddrLen - hlen := uint8(len(d.clientHwAddr)) - if hlen == 0 && d.hwType == iana.HwTypeEthernet { + hlen := uint8(len(d.ClientHWAddr)) + if hlen == 0 && d.HWType == iana.HwTypeEthernet { hlen = 6 } buf.Write8(hlen) + buf.Write8(d.HopCount) + buf.WriteBytes(d.TransactionID[:]) + buf.Write16(d.NumSeconds) + buf.Write16(d.Flags) - buf.Write8(d.hopCount) - buf.WriteBytes(d.transactionID[:]) - buf.Write16(d.numSeconds) - buf.Write16(d.flags) - - writeIP(buf, d.clientIPAddr[:]) - writeIP(buf, d.yourIPAddr[:]) - writeIP(buf, d.serverIPAddr[:]) - writeIP(buf, d.gatewayIPAddr[:]) - - copy(buf.WriteN(maxHWAddrLen), d.clientHwAddr) + writeIP(buf, d.ClientIPAddr) + writeIP(buf, d.YourIPAddr) + writeIP(buf, d.ServerIPAddr) + writeIP(buf, d.GatewayIPAddr) + copy(buf.WriteN(16), d.ClientHWAddr) var sname [64]byte - copy(sname[:], []byte(d.serverHostName)) + copy(sname[:], []byte(d.ServerHostName)) + sname[len(d.ServerHostName)] = 0 buf.WriteBytes(sname[:]) var file [128]byte - copy(file[:], []byte(d.bootFileName)) + copy(file[:], []byte(d.BootFileName)) + file[len(d.BootFileName)] = 0 buf.WriteBytes(file[:]) // The magic cookie. buf.WriteBytes(magicCookie[:]) - for _, opt := range d.options { + for _, opt := range d.Options { buf.WriteBytes(opt.ToBytes()) } return buf.Data() diff --git a/dhcpv4/dhcpv4_test.go b/dhcpv4/dhcpv4_test.go index 893377d..f6f9d7c 100644 --- a/dhcpv4/dhcpv4_test.go +++ b/dhcpv4/dhcpv4_test.go @@ -61,18 +61,18 @@ func TestFromBytes(t *testing.T) { d, err := FromBytes(data) require.NoError(t, err) - require.Equal(t, d.Opcode(), OpcodeBootRequest) - require.Equal(t, d.HwType(), iana.HwTypeEthernet) - require.Equal(t, d.HopCount(), byte(3)) - require.Equal(t, d.TransactionID(), TransactionID{0xaa, 0xbb, 0xcc, 0xdd}) - require.Equal(t, d.NumSeconds(), uint16(3)) - require.Equal(t, d.Flags(), uint16(1)) - require.True(t, d.ClientIPAddr().Equal(net.IPv4zero)) - require.True(t, d.YourIPAddr().Equal(net.IPv4zero)) - require.True(t, d.GatewayIPAddr().Equal(net.IPv4zero)) - require.Equal(t, d.ClientHwAddr(), net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}) - require.Equal(t, d.ServerHostName(), "") - require.Equal(t, d.BootFileName(), "") + require.Equal(t, d.OpCode, OpcodeBootRequest) + require.Equal(t, d.HWType, iana.HwTypeEthernet) + require.Equal(t, d.HopCount, byte(3)) + require.Equal(t, d.TransactionID, TransactionID{0xaa, 0xbb, 0xcc, 0xdd}) + require.Equal(t, d.NumSeconds, uint16(3)) + require.Equal(t, d.Flags, uint16(1)) + require.True(t, d.ClientIPAddr.Equal(net.IPv4zero)) + require.True(t, d.YourIPAddr.Equal(net.IPv4zero)) + require.True(t, d.GatewayIPAddr.Equal(net.IPv4zero)) + require.Equal(t, d.ClientHWAddr, net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}) + require.Equal(t, d.ServerHostName, "") + require.Equal(t, d.BootFileName, "") // no need to check Magic Cookie as it is already validated in FromBytes // above } @@ -118,143 +118,19 @@ func TestFromBytesInvalidOptions(t *testing.T) { require.Error(t, err) } -func TestSettersAndGetters(t *testing.T) { - data := []byte{ - 1, // dhcp request - 1, // ethernet hw type - 6, // hw addr length - 3, // hop count - 0xaa, 0xbb, 0xcc, 0xdd, // transaction ID, big endian (network) - 0, 3, // number of seconds - 0, 1, // broadcast - 1, 2, 3, 4, // client IP address - 5, 6, 7, 8, // your IP address - 9, 10, 11, 12, // server IP address - 13, 14, 15, 16, // gateway IP address - 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // client MAC address + padding - } - // server host name - expectedHostname := []byte{} - for i := 0; i < 64; i++ { - expectedHostname = append(expectedHostname, 0) - } - data = append(data, expectedHostname...) - // boot file name - expectedBootfilename := []byte{} - for i := 0; i < 128; i++ { - expectedBootfilename = append(expectedBootfilename, 0) - } - data = append(data, expectedBootfilename...) - // magic cookie, then no options - data = append(data, []byte{99, 130, 83, 99}...) - d, err := FromBytes(data) - require.NoError(t, err) - - // getter/setter for Opcode - require.Equal(t, OpcodeBootRequest, d.Opcode()) - d.SetOpcode(OpcodeBootReply) - require.Equal(t, OpcodeBootReply, d.Opcode()) - - // getter/setter for HwType - require.Equal(t, iana.HwTypeEthernet, d.HwType()) - d.SetHwType(iana.HwTypeARCNET) - require.Equal(t, iana.HwTypeARCNET, d.HwType()) - - // getter/setter for HopCount - require.Equal(t, uint8(3), d.HopCount()) - d.SetHopCount(1) - require.Equal(t, uint8(1), d.HopCount()) - - // getter/setter for TransactionID - require.Equal(t, TransactionID{0xaa, 0xbb, 0xcc, 0xdd}, d.TransactionID()) - d.SetTransactionID(TransactionID{0xee, 0xff, 0x00, 0x11}) - require.Equal(t, TransactionID{0xee, 0xff, 0x00, 0x11}, d.TransactionID()) - - // getter/setter for TransactionID - require.Equal(t, uint16(3), d.NumSeconds()) - d.SetNumSeconds(15) - require.Equal(t, uint16(15), d.NumSeconds()) - - // getter/setter for Flags - require.Equal(t, uint16(1), d.Flags()) - d.SetFlags(0) - require.Equal(t, uint16(0), d.Flags()) - - // getter/setter for ClientIPAddr - require.True(t, d.ClientIPAddr().Equal(net.IPv4(1, 2, 3, 4))) - d.SetClientIPAddr(net.IPv4(4, 3, 2, 1)) - require.True(t, d.ClientIPAddr().Equal(net.IPv4(4, 3, 2, 1))) - - // getter/setter for YourIPAddr - require.True(t, d.YourIPAddr().Equal(net.IPv4(5, 6, 7, 8))) - d.SetYourIPAddr(net.IPv4(8, 7, 6, 5)) - require.True(t, d.YourIPAddr().Equal(net.IPv4(8, 7, 6, 5))) - - // getter/setter for ServerIPAddr - require.True(t, d.ServerIPAddr().Equal(net.IPv4(9, 10, 11, 12))) - d.SetServerIPAddr(net.IPv4(12, 11, 10, 9)) - require.True(t, d.ServerIPAddr().Equal(net.IPv4(12, 11, 10, 9))) - - // getter/setter for GatewayIPAddr - require.True(t, d.GatewayIPAddr().Equal(net.IPv4(13, 14, 15, 16))) - d.SetGatewayIPAddr(net.IPv4(16, 15, 14, 13)) - require.True(t, d.GatewayIPAddr().Equal(net.IPv4(16, 15, 14, 13))) - - // getter/setter for ClientHwAddr - require.Equal(t, net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, d.ClientHwAddr()) - d.SetFlags(0) - - // getter/setter for ServerHostName - require.Equal(t, "", d.ServerHostName()) - d.SetServerHostName("test") - require.Equal(t, "test", d.ServerHostName()) - - // getter/setter for BootFileName - require.Equal(t, "", d.BootFileName()) - d.SetBootFileName("test") - require.Equal(t, "test", d.BootFileName()) -} - func TestToStringMethods(t *testing.T) { d, err := New() if err != nil { t.Fatal(err) } - // OpcodeToString - d.SetOpcode(OpcodeBootRequest) - require.Equal(t, "BootRequest", d.OpcodeToString()) - d.SetOpcode(OpcodeBootReply) - require.Equal(t, "BootReply", d.OpcodeToString()) - d.SetOpcode(OpcodeType(0)) - require.Equal(t, "Unknown", d.OpcodeToString()) - - // HwTypeToString - d.SetHwType(iana.HwTypeEthernet) - require.Equal(t, "Ethernet", d.HwTypeToString()) - d.SetHwType(iana.HwTypeARCNET) - require.Equal(t, "ARCNET", d.HwTypeToString()) - d.SetHwType(iana.HwTypeType(0)) - require.Equal(t, "Invalid", d.HwTypeToString()) // FlagsToString d.SetUnicast() require.Equal(t, "Unicast", d.FlagsToString()) d.SetBroadcast() require.Equal(t, "Broadcast", d.FlagsToString()) - d.SetFlags(0xffff) + d.Flags = 0xffff require.Equal(t, "Broadcast (reserved bits not zeroed)", d.FlagsToString()) - - // ClientHwAddrToString - d.SetClientHwAddr(net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}) - require.Equal(t, "aa:bb:cc:dd:ee:ff", d.ClientHwAddrToString()) - - // ServerHostNameToString - d.SetServerHostName("my.host.local") - require.Equal(t, "my.host.local", d.ServerHostName()) - - // BootFileNameToString - d.SetBootFileName("/my/boot/file") - require.Equal(t, "/my/boot/file", d.BootFileName()) } func TestNewToBytes(t *testing.T) { @@ -291,7 +167,7 @@ func TestNewToBytes(t *testing.T) { require.NoError(t, err) // fix TransactionID to match the expected one, since it's randomly // generated in New() - d.SetTransactionID(TransactionID{0x11, 0x22, 0x33, 0x44}) + d.TransactionID = TransactionID{0x11, 0x22, 0x33, 0x44} got := d.ToBytes() require.Equal(t, expected, got) } @@ -329,7 +205,7 @@ func TestAddOption(t *testing.T) { d.AddOption(bootFileOpt1) d.AddOption(bootFileOpt2) - options := d.Options() + options := d.Options require.Equal(t, len(options), 4) require.Equal(t, options[3].Code(), OptionEnd) } @@ -337,18 +213,18 @@ func TestAddOption(t *testing.T) { func TestUpdateOption(t *testing.T) { d, err := New() require.NoError(t, err) - require.Equal(t, 1, len(d.options)) - require.Equal(t, OptionEnd, d.options[0].Code()) + require.Equal(t, 1, len(d.Options)) + require.Equal(t, OptionEnd, d.Options[0].Code()) // test that it will add the option since it's missing d.UpdateOption(&OptDomainName{DomainName: "slackware.it"}) - require.Equal(t, 2, len(d.options)) - require.Equal(t, OptionDomainName, d.options[0].Code()) - require.Equal(t, OptionEnd, d.options[1].Code()) + require.Equal(t, 2, len(d.Options)) + require.Equal(t, OptionDomainName, d.Options[0].Code()) + require.Equal(t, OptionEnd, d.Options[1].Code()) // test that it won't add another option of the same type d.UpdateOption(&OptDomainName{DomainName: "slackware.it"}) - require.Equal(t, 2, len(d.options)) - require.Equal(t, OptionDomainName, d.options[0].Code()) - require.Equal(t, OptionEnd, d.options[1].Code()) + require.Equal(t, 2, len(d.Options)) + require.Equal(t, OptionDomainName, d.Options[0].Code()) + require.Equal(t, OptionEnd, d.Options[1].Code()) } func TestStrippedOptions(t *testing.T) { @@ -360,7 +236,7 @@ func TestStrippedOptions(t *testing.T) { &OptClassIdentifier{"something"}, &OptionGeneric{OptionCode: OptionEnd}, } - d.SetOptions(opts) + d.Options = opts stripped := d.StrippedOptions() require.Equal(t, len(opts), len(stripped)) for i := range stripped { @@ -369,7 +245,7 @@ func TestStrippedOptions(t *testing.T) { // Set of options with additional options after OptionEnd opts = append(opts, &OptMaximumDHCPMessageSize{uint16(1234)}) - d.SetOptions(opts) + d.Options = opts stripped = d.StrippedOptions() require.Equal(t, len(opts)-1, len(stripped)) for i := range stripped { @@ -391,8 +267,7 @@ func TestDHCPv4NewRequestFromOffer(t *testing.T) { // Broadcast request req, err = NewRequestFromOffer(offer) require.NoError(t, err) - require.NotNil(t, req.MessageType()) - require.Equal(t, MessageTypeRequest, *req.MessageType()) + require.Equal(t, MessageTypeRequest, req.MessageType()) require.False(t, req.IsUnicast()) require.True(t, req.IsBroadcast()) @@ -412,50 +287,48 @@ func TestDHCPv4NewRequestFromOfferWithModifier(t *testing.T) { userClass := WithUserClass([]byte("linuxboot"), false) req, err := NewRequestFromOffer(offer, userClass) require.NoError(t, err) - require.NotEqual(t, (*MessageType)(nil), *req.MessageType()) - require.Equal(t, MessageTypeRequest, *req.MessageType()) - require.Equal(t, "User Class Information -> linuxboot", req.options[3].String()) + require.Equal(t, MessageTypeRequest, req.MessageType()) + require.Equal(t, "User Class Information -> linuxboot", req.Options[3].String()) } func TestNewReplyFromRequest(t *testing.T) { discover, err := New() require.NoError(t, err) - discover.SetGatewayIPAddr(net.IPv4(192, 168, 0, 1)) + discover.GatewayIPAddr = net.IPv4(192, 168, 0, 1) reply, err := NewReplyFromRequest(discover) require.NoError(t, err) - require.Equal(t, discover.TransactionID(), reply.TransactionID()) - require.Equal(t, discover.GatewayIPAddr(), reply.GatewayIPAddr()) + require.Equal(t, discover.TransactionID, reply.TransactionID) + require.Equal(t, discover.GatewayIPAddr, reply.GatewayIPAddr) } func TestNewReplyFromRequestWithModifier(t *testing.T) { discover, err := New() require.NoError(t, err) - discover.SetGatewayIPAddr(net.IPv4(192, 168, 0, 1)) + discover.GatewayIPAddr = net.IPv4(192, 168, 0, 1) userClass := WithUserClass([]byte("linuxboot"), false) reply, err := NewReplyFromRequest(discover, userClass) require.NoError(t, err) - require.Equal(t, discover.TransactionID(), reply.TransactionID()) - require.Equal(t, discover.GatewayIPAddr(), reply.GatewayIPAddr()) - require.Equal(t, "User Class Information -> linuxboot", reply.options[0].String()) + require.Equal(t, discover.TransactionID, reply.TransactionID) + require.Equal(t, discover.GatewayIPAddr, reply.GatewayIPAddr) + require.Equal(t, "User Class Information -> linuxboot", reply.Options[0].String()) } func TestDHCPv4MessageTypeNil(t *testing.T) { m, err := New() require.NoError(t, err) - require.Nil(t, m.MessageType()) + require.Equal(t, MessageTypeNone, m.MessageType()) } func TestNewDiscovery(t *testing.T) { hwAddr := net.HardwareAddr{1, 2, 3, 4, 5, 6} m, err := NewDiscovery(hwAddr) require.NoError(t, err) - require.NotNil(t, m.MessageType()) - require.Equal(t, MessageTypeDiscover, *m.MessageType()) + require.Equal(t, MessageTypeDiscover, m.MessageType()) // Validate fields of DISCOVER packet. - require.Equal(t, OpcodeBootRequest, m.Opcode()) - require.Equal(t, iana.HwTypeEthernet, m.HwType()) - require.Equal(t, hwAddr, m.ClientHwAddr()) + require.Equal(t, OpcodeBootRequest, m.OpCode) + require.Equal(t, iana.HwTypeEthernet, m.HWType) + require.Equal(t, hwAddr, m.ClientHWAddr) require.True(t, m.IsBroadcast()) require.True(t, HasOption(m, OptionParameterRequestList)) require.True(t, HasOption(m, OptionEnd)) @@ -467,12 +340,11 @@ func TestNewInform(t *testing.T) { m, err := NewInform(hwAddr, localIP) require.NoError(t, err) - require.Equal(t, OpcodeBootRequest, m.Opcode()) - require.Equal(t, iana.HwTypeEthernet, m.HwType()) - require.Equal(t, hwAddr, m.ClientHwAddr()) - require.NotNil(t, m.MessageType()) - require.Equal(t, MessageTypeInform, *m.MessageType()) - require.True(t, m.ClientIPAddr().Equal(localIP)) + require.Equal(t, OpcodeBootRequest, m.OpCode) + require.Equal(t, iana.HwTypeEthernet, m.HWType) + require.Equal(t, hwAddr, m.ClientHWAddr) + require.Equal(t, MessageTypeInform, m.MessageType()) + require.True(t, m.ClientIPAddr.Equal(localIP)) } func TestIsOptionRequested(t *testing.T) { diff --git a/dhcpv4/modifiers.go b/dhcpv4/modifiers.go index ba58884..db67303 100644 --- a/dhcpv4/modifiers.go +++ b/dhcpv4/modifiers.go @@ -9,7 +9,7 @@ import ( // WithTransactionID sets the Transaction ID for the DHCPv4 packet func WithTransactionID(xid TransactionID) Modifier { return func(d *DHCPv4) *DHCPv4 { - d.SetTransactionID(xid) + d.TransactionID = xid return d } } @@ -27,9 +27,9 @@ func WithBroadcast(broadcast bool) Modifier { } // WithHwAddr sets the hardware address for a packet -func WithHwAddr(hwaddr []byte) Modifier { +func WithHwAddr(hwaddr net.HardwareAddr) Modifier { return func(d *DHCPv4) *DHCPv4 { - d.SetClientHwAddr(hwaddr) + d.ClientHWAddr = hwaddr return d } } @@ -111,8 +111,8 @@ func WithRequestedOptions(optionCodes ...OptionCode) Modifier { func WithRelay(ip net.IP) Modifier { return func(d *DHCPv4) *DHCPv4 { d.SetUnicast() - d.SetGatewayIPAddr(ip) - d.SetHopCount(1) + d.GatewayIPAddr = ip + d.HopCount = 1 return d } } diff --git a/dhcpv4/modifiers_test.go b/dhcpv4/modifiers_test.go index f3d8ead..15dbebf 100644 --- a/dhcpv4/modifiers_test.go +++ b/dhcpv4/modifiers_test.go @@ -11,7 +11,7 @@ func TestTransactionIDModifier(t *testing.T) { d, err := New() require.NoError(t, err) d = WithTransactionID(TransactionID{0xdd, 0xcc, 0xbb, 0xaa})(d) - require.Equal(t, TransactionID{0xdd, 0xcc, 0xbb, 0xaa}, d.TransactionID()) + require.Equal(t, TransactionID{0xdd, 0xcc, 0xbb, 0xaa}, d.TransactionID) } func TestBroadcastModifier(t *testing.T) { @@ -30,7 +30,7 @@ func TestHwAddrModifier(t *testing.T) { require.NoError(t, err) hwaddr := net.HardwareAddr{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} d = WithHwAddr(hwaddr)(d) - require.Equal(t, hwaddr, d.ClientHwAddr()) + require.Equal(t, hwaddr, d.ClientHWAddr) } func TestWithOptionModifier(t *testing.T) { @@ -53,8 +53,8 @@ func TestUserClassModifier(t *testing.T) { 9, // length 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } - require.Equal(t, "User Class Information -> linuxboot", d.options[0].String()) - require.Equal(t, expected, d.options[0].ToBytes()) + require.Equal(t, "User Class Information -> linuxboot", d.Options[0].String()) + require.Equal(t, expected, d.Options[0].ToBytes()) } func TestUserClassModifierRFC(t *testing.T) { @@ -66,14 +66,14 @@ func TestUserClassModifierRFC(t *testing.T) { 10, // length 9, 'l', 'i', 'n', 'u', 'x', 'b', 'o', 'o', 't', } - require.Equal(t, "User Class Information -> linuxboot", d.options[0].String()) - require.Equal(t, expected, d.options[0].ToBytes()) + require.Equal(t, "User Class Information -> linuxboot", d.Options[0].String()) + require.Equal(t, expected, d.Options[0].ToBytes()) } func TestWithNetboot(t *testing.T) { d, _ := New() d = WithNetboot(d) - require.Equal(t, "Parameter Request List -> [TFTP Server Name, Bootfile Name]", d.options[0].String()) + require.Equal(t, "Parameter Request List -> [TFTP Server Name, Bootfile Name]", d.Options[0].String()) } func TestWithNetbootExistingTFTP(t *testing.T) { @@ -83,7 +83,7 @@ func TestWithNetbootExistingTFTP(t *testing.T) { } d.AddOption(OptParams) d = WithNetboot(d) - require.Equal(t, "Parameter Request List -> [TFTP Server Name, Bootfile Name]", d.options[0].String()) + require.Equal(t, "Parameter Request List -> [TFTP Server Name, Bootfile Name]", d.Options[0].String()) } func TestWithNetbootExistingBootfileName(t *testing.T) { @@ -93,7 +93,7 @@ func TestWithNetbootExistingBootfileName(t *testing.T) { } d.AddOption(OptParams) d = WithNetboot(d) - require.Equal(t, "Parameter Request List -> [Bootfile Name, TFTP Server Name]", d.options[0].String()) + require.Equal(t, "Parameter Request List -> [Bootfile Name, TFTP Server Name]", d.Options[0].String()) } func TestWithNetbootExistingBoth(t *testing.T) { @@ -103,7 +103,7 @@ func TestWithNetbootExistingBoth(t *testing.T) { } d.AddOption(OptParams) d = WithNetboot(d) - require.Equal(t, "Parameter Request List -> [Bootfile Name, TFTP Server Name]", d.options[0].String()) + require.Equal(t, "Parameter Request List -> [Bootfile Name, TFTP Server Name]", d.Options[0].String()) } func TestWithRequestedOptions(t *testing.T) { @@ -133,34 +133,34 @@ func TestWithRelay(t *testing.T) { d = WithRelay(ip)(d) require.NotNil(t, d) require.True(t, d.IsUnicast(), "expected unicast") - require.Equal(t, ip, d.GatewayIPAddr()) - require.Equal(t, uint8(1), d.HopCount()) + require.Equal(t, ip, d.GatewayIPAddr) + require.Equal(t, uint8(1), d.HopCount) } func TestWithNetmask(t *testing.T) { d := &DHCPv4{} d = WithNetmask(net.IPv4Mask(255, 255, 255, 0))(d) - require.Equal(t, 1, len(d.options)) - require.Equal(t, OptionSubnetMask, d.options[0].Code()) - osm := d.options[0].(*OptSubnetMask) + require.Equal(t, 1, len(d.Options)) + require.Equal(t, OptionSubnetMask, d.Options[0].Code()) + osm := d.Options[0].(*OptSubnetMask) require.Equal(t, net.IPv4Mask(255, 255, 255, 0), osm.SubnetMask) } func TestWithLeaseTime(t *testing.T) { d := &DHCPv4{} d = WithLeaseTime(uint32(3600))(d) - require.Equal(t, 1, len(d.options)) - require.Equal(t, OptionIPAddressLeaseTime, d.options[0].Code()) - olt := d.options[0].(*OptIPAddressLeaseTime) + require.Equal(t, 1, len(d.Options)) + require.Equal(t, OptionIPAddressLeaseTime, d.Options[0].Code()) + olt := d.Options[0].(*OptIPAddressLeaseTime) require.Equal(t, uint32(3600), olt.LeaseTime) } func TestWithDNS(t *testing.T) { d := &DHCPv4{} d = WithDNS(net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"))(d) - require.Equal(t, 1, len(d.options)) - require.Equal(t, OptionDomainNameServer, d.options[0].Code()) - olt := d.options[0].(*OptDomainNameServer) + require.Equal(t, 1, len(d.Options)) + require.Equal(t, OptionDomainNameServer, d.Options[0].Code()) + olt := d.Options[0].(*OptDomainNameServer) require.Equal(t, 2, len(olt.NameServers)) require.Equal(t, net.ParseIP("10.0.0.1"), olt.NameServers[0]) require.Equal(t, net.ParseIP("10.0.0.2"), olt.NameServers[1]) @@ -170,8 +170,8 @@ func TestWithDNS(t *testing.T) { func TestWithDomainSearchList(t *testing.T) { d := &DHCPv4{} d = WithDomainSearchList("slackware.it", "dhcp.slackware.it")(d) - require.Equal(t, 1, len(d.options)) - osl := d.options[0].(*OptDomainSearch) + require.Equal(t, 1, len(d.Options)) + osl := d.Options[0].(*OptDomainSearch) require.Equal(t, OptionDNSDomainSearchList, osl.Code()) require.NotNil(t, osl.DomainSearch) require.Equal(t, 2, len(osl.DomainSearch.Labels)) @@ -183,8 +183,8 @@ func TestWithRouter(t *testing.T) { d := &DHCPv4{} rtr := net.ParseIP("10.0.0.254") d = WithRouter(rtr)(d) - require.Equal(t, 1, len(d.options)) - ortr := d.options[0].(*OptRouter) + require.Equal(t, 1, len(d.Options)) + ortr := d.Options[0].(*OptRouter) require.Equal(t, OptionRouter, ortr.Code()) require.Equal(t, 1, len(ortr.Routers)) require.Equal(t, rtr, ortr.Routers[0]) diff --git a/dhcpv4/server_test.go b/dhcpv4/server_test.go index 454a113..bb2ee11 100644 --- a/dhcpv4/server_test.go +++ b/dhcpv4/server_test.go @@ -32,7 +32,7 @@ func DORAHandler(conn net.PacketConn, peer net.Addr, m *DHCPv4) { log.Printf("Packet is nil!") return } - if m.Opcode() != OpcodeBootRequest { + if m.OpCode != OpcodeBootRequest { log.Printf("Not a BootRequest!") return } @@ -114,18 +114,18 @@ func TestServerActivateAndServe(t *testing.T) { require.NotEqual(t, 0, len(ifaces)) xid := TransactionID{0xaa, 0xbb, 0xcc, 0xdd} - hwaddr := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} + hwaddr := net.HardwareAddr{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} modifiers := []Modifier{ WithTransactionID(xid), - WithHwAddr(hwaddr[:]), + WithHwAddr(hwaddr), } conv, err := c.Exchange(ifaces[0].Name, modifiers...) require.NoError(t, err) require.Equal(t, 4, len(conv)) for _, p := range conv { - require.Equal(t, xid, p.TransactionID()) - require.Equal(t, [16]byte(hwaddr), p.ClientHwAddr()) + require.Equal(t, xid, p.TransactionID) + require.Equal(t, hwaddr, p.ClientHWAddr) } } diff --git a/dhcpv4/types.go b/dhcpv4/types.go index aea5caf..c352292 100644 --- a/dhcpv4/types.go +++ b/dhcpv4/types.go @@ -28,14 +28,13 @@ const ( ) func (m MessageType) String() string { - if s, ok := MessageTypeToString[m]; ok { + if s, ok := messageTypeToString[m]; ok { return s } return "Unknown" } -// MessageTypeToString maps DHCP message types to human-readable strings. -var MessageTypeToString = map[MessageType]string{ +var messageTypeToString = map[MessageType]string{ MessageTypeDiscover: "DISCOVER", MessageTypeOffer: "OFFER", MessageTypeRequest: "REQUEST", @@ -56,14 +55,13 @@ const ( ) func (o OpcodeType) String() string { - if s, ok := OpcodeToString[o]; ok { + if s, ok := opcodeToString[o]; ok { return s } return "Unknown" } -// OpcodeToString maps an OpcodeType to its mnemonic name -var OpcodeToString = map[OpcodeType]string{ +var opcodeToString = map[OpcodeType]string{ OpcodeBootRequest: "BootRequest", OpcodeBootReply: "BootReply", } @@ -233,14 +231,13 @@ const ( ) func (o OptionCode) String() string { - if s, ok := OptionCodeToString[o]; ok { + if s, ok := optionCodeToString[o]; ok { return s } return "Unknown" } -// OptionCodeToString maps an OptionCode to its mnemonic name -var OptionCodeToString = map[OptionCode]string{ +var optionCodeToString = map[OptionCode]string{ OptionPad: "Pad", OptionSubnetMask: "Subnet Mask", OptionTimeOffset: "Time Offset", |