summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4
diff options
context:
space:
mode:
authorHu Jun <hujun.work@gmail.com>2020-09-20 16:18:11 -0700
committerHu Jun <hujun.work@gmail.com>2020-09-20 16:18:11 -0700
commit305577bdea50fad7732f75ef37864255b2f46d43 (patch)
treebd68c682d32a5f9d61e7b7074596e64f91921dcc /dhcpv4
parent79aba137cf3ea8e0c1c980cad412cc034e145c0e (diff)
parent1a1c38473709f69a75e1d90fb3b4ff63f7b8c2cd (diff)
Merge branch 'master' of https://github.com/insomniacslk/dhcp
Diffstat (limited to 'dhcpv4')
-rw-r--r--dhcpv4/dhcpv4.go7
-rw-r--r--dhcpv4/nclient4/client.go78
-rw-r--r--dhcpv4/nclient4/conn_linux.go22
-rw-r--r--dhcpv4/nclient4/ipv4.go271
-rw-r--r--dhcpv4/option_string.go5
-rw-r--r--dhcpv4/server4/server.go105
6 files changed, 243 insertions, 245 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index 4316c5e..c09e76a 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -735,6 +735,13 @@ func (d *DHCPv4) MessageType() MessageType {
return m
}
+// Message returns the DHCPv4 (Error) Message option.
+//
+// The message options is described in RFC 2132, Section 9.9.
+func (d *DHCPv4) Message() string {
+ return GetString(OptionMessage, d.Options)
+}
+
// ParameterRequestList returns the DHCPv4 Parameter Request List.
//
// The parameter request list option is described by RFC 2132, Section 9.8.
diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go
index 83ed065..8fedf4a 100644
--- a/dhcpv4/nclient4/client.go
+++ b/dhcpv4/nclient4/client.go
@@ -399,15 +399,36 @@ func WithServerAddr(n *net.UDPAddr) ClientOpt {
// Matcher matches DHCP packets.
type Matcher func(*dhcpv4.DHCPv4) bool
-// IsMessageType returns a matcher that checks for the message type.
-//
-// If t is MessageTypeNone, all packets are matched.
-func IsMessageType(t dhcpv4.MessageType) Matcher {
+// IsMessageType returns a matcher that checks for the message types.
+func IsMessageType(t dhcpv4.MessageType, tt ...dhcpv4.MessageType) Matcher {
return func(p *dhcpv4.DHCPv4) bool {
- return p.MessageType() == t || t == dhcpv4.MessageTypeNone
+ if p.MessageType() == t {
+ return true
+ }
+ for _, mt := range tt {
+ if p.MessageType() == mt {
+ return true
+ }
+ }
+ return false
}
}
+// RemoteAddr is the default DHCP server address this client sends messages to.
+func (c *Client) RemoteAddr() *net.UDPAddr {
+ // Make a copy so the caller cannot modify the address once the client
+ // is running.
+ cop := *c.serverAddr
+ return &cop
+}
+
+// InterfaceAddr returns the MAC address of the client's interface.
+func (c *Client) InterfaceAddr() net.HardwareAddr {
+ b := make(net.HardwareAddr, len(c.ifaceHWAddr))
+ copy(b, c.ifaceHWAddr)
+ return b
+}
+
// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer
// received.
func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) {
@@ -416,17 +437,14 @@ func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier
discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers,
dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
if err != nil {
- err = fmt.Errorf("unable to create a discovery request: %w", err)
- return
+ return nil, fmt.Errorf("unable to create a discovery request: %w", err)
}
offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer))
if err != nil {
- err = fmt.Errorf("got an error while the discovery request: %w", err)
- return
+ return nil, fmt.Errorf("got an error while the discovery request: %w", err)
}
-
- return
+ return offer, nil
}
// Request completes the 4-way Discover-Offer-Request-Ack handshake.
@@ -438,25 +456,47 @@ func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (lea
err = fmt.Errorf("unable to receive an offer: %w", err)
return
}
+ return c.RequestFromOffer(ctx, offer, modifiers...)
+}
+
+// ErrNak is returned if a DHCP server rejected our Request.
+type ErrNak struct {
+ Offer *dhcpv4.DHCPv4
+ Nak *dhcpv4.DHCPv4
+}
+
+// Error implements error.Error.
+func (e *ErrNak) Error() string {
+ if msg := e.Nak.Message(); len(msg) > 0 {
+ return fmt.Sprintf("server rejected request with Nak (msg: %s)", msg)
+ }
+ return "server rejected request with Nak"
+}
+// RequestFromOffer sends a Request message and waits for an response.
+func (c *Client) RequestFromOffer(ctx context.Context, offer *dhcpv4.DHCPv4, modifiers ...dhcpv4.Modifier) (*Lease, error) {
// TODO(chrisko): should this be unicast to the server?
request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers,
dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
if err != nil {
- err = fmt.Errorf("unable to create a request: %w", err)
- return
+ return nil, fmt.Errorf("unable to create a request: %w", err)
}
- ack, err := c.SendAndRead(ctx, c.serverAddr, request, nil)
+ response, err := c.SendAndRead(ctx, c.serverAddr, request, IsMessageType(dhcpv4.MessageTypeAck, dhcpv4.MessageTypeNak))
if err != nil {
- err = fmt.Errorf("got an error while processing the request: %w", err)
- return
+ return nil, fmt.Errorf("got an error while processing the request: %w", err)
+ }
+ if response.MessageType() == dhcpv4.MessageTypeNak {
+ return nil, &ErrNak{
+ Offer: offer,
+ Nak: response,
+ }
}
- lease = &Lease{}
- lease.ACK = ack
+ lease := &Lease{}
+ lease.ACK = response
lease.Offer = offer
lease.CreationTime = time.Now()
- return
+ return lease, nil
}
// ErrTransactionIDInUse is returned if there were an attempt to send a message
diff --git a/dhcpv4/nclient4/conn_linux.go b/dhcpv4/nclient4/conn_linux.go
index 1d0ec3a..9257bee 100644
--- a/dhcpv4/nclient4/conn_linux.go
+++ b/dhcpv4/nclient4/conn_linux.go
@@ -82,8 +82,8 @@ func udpMatch(addr *net.UDPAddr, bound *net.UDPAddr) bool {
// ReadFrom reads raw IP packets and will try to match them against
// upc.boundAddr. Any matching packets are returned via the given buffer.
func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
- ipHdrMaxLen := IPv4MaximumHeaderSize
- udpHdrLen := UDPMinimumSize
+ ipHdrMaxLen := ipv4MaximumHeaderSize
+ udpHdrLen := udpMinimumSize
for {
pkt := make([]byte, ipHdrMaxLen+udpHdrLen+len(b))
@@ -98,28 +98,28 @@ func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
buf := uio.NewBigEndianBuffer(pkt)
// To read the header length, access data directly.
- ipHdr := IPv4(buf.Data())
- ipHdr = IPv4(buf.Consume(int(ipHdr.HeaderLength())))
+ ipHdr := ipv4(buf.Data())
+ ipHdr = ipv4(buf.Consume(int(ipHdr.headerLength())))
- if ipHdr.TransportProtocol() != UDPProtocolNumber {
+ if ipHdr.transportProtocol() != udpProtocolNumber {
continue
}
- udpHdr := UDP(buf.Consume(udpHdrLen))
+ udpHdr := udp(buf.Consume(udpHdrLen))
addr := &net.UDPAddr{
- IP: ipHdr.DestinationAddress(),
- Port: int(udpHdr.DestinationPort()),
+ IP: ipHdr.destinationAddress(),
+ Port: int(udpHdr.destinationPort()),
}
if !udpMatch(addr, upc.boundAddr) {
continue
}
srcAddr := &net.UDPAddr{
- IP: ipHdr.SourceAddress(),
- Port: int(udpHdr.SourcePort()),
+ IP: ipHdr.sourceAddress(),
+ Port: int(udpHdr.sourcePort()),
}
// Extra padding after end of IP packet should be ignored,
// if not dhcp option parsing will fail.
- dhcpLen := int(ipHdr.PayloadLength()) - udpHdrLen
+ dhcpLen := int(ipHdr.payloadLength()) - udpHdrLen
return copy(b, buf.Consume(dhcpLen)), srcAddr, nil
}
}
diff --git a/dhcpv4/nclient4/ipv4.go b/dhcpv4/nclient4/ipv4.go
index 5733eb4..2dffd29 100644
--- a/dhcpv4/nclient4/ipv4.go
+++ b/dhcpv4/nclient4/ipv4.go
@@ -26,24 +26,24 @@ import (
)
const (
- versIHL = 0
- tos = 1
- totalLen = 2
- id = 4
- flagsFO = 6
- ttl = 8
- protocol = 9
- checksum = 10
- srcAddr = 12
- dstAddr = 16
+ versIHL = 0
+ tos = 1
+ totalLen = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksumOff = 10
+ srcAddr = 12
+ dstAddr = 16
)
-// TransportProtocolNumber is the number of a transport protocol.
-type TransportProtocolNumber uint32
+// transportProtocolNumber is the number of a transport protocol.
+type transportProtocolNumber uint32
-// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
+// ipv4Fields contains the fields of an IPv4 packet. It is used to describe the
// fields of a packet that needs to be encoded.
-type IPv4Fields struct {
+type ipv4Fields struct {
// IHL is the "internet header length" field of an IPv4 packet.
IHL uint8
@@ -68,8 +68,8 @@ type IPv4Fields struct {
// Protocol is the "protocol" field of an IPv4 packet.
Protocol uint8
- // Checksum is the "checksum" field of an IPv4 packet.
- Checksum uint16
+ // checksum is the "checksum" field of an IPv4 packet.
+ checksum uint16
// SrcAddr is the "source ip address" of an IPv4 packet.
SrcAddr net.IP
@@ -78,142 +78,109 @@ type IPv4Fields struct {
DstAddr net.IP
}
-// IPv4 represents an ipv4 header stored in a byte array.
+// ipv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv4 before using other methods.
-type IPv4 []byte
+type ipv4 []byte
const (
- // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
- IPv4MinimumSize = 20
+ // ipv4MinimumSize is the minimum size of a valid IPv4 packet.
+ ipv4MinimumSize = 20
- // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
+ // ipv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
// that there are only 4 bits to represents the header length in 32-bit
// units, the header cannot exceed 15*4 = 60 bytes.
- IPv4MaximumHeaderSize = 60
+ ipv4MaximumHeaderSize = 60
- // IPv4AddressSize is the size, in bytes, of an IPv4 address.
- IPv4AddressSize = 4
-
- // IPv4Version is the version of the ipv4 protocol.
- IPv4Version = 4
-)
-
-var (
- // IPv4Broadcast is the broadcast address of the IPv4 protocol.
- IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff}
-
- // IPv4Any is the non-routable IPv4 "any" meta address.
- IPv4Any = net.IP{0, 0, 0, 0}
+ // ipv4AddressSize is the size, in bytes, of an IPv4 address.
+ ipv4AddressSize = 4
)
-// Flags that may be set in an IPv4 packet.
-const (
- IPv4FlagMoreFragments = 1 << iota
- IPv4FlagDontFragment
-)
-
-// HeaderLength returns the value of the "header length" field of the ipv4
+// headerLength returns the value of the "header length" field of the ipv4
// header.
-func (b IPv4) HeaderLength() uint8 {
+func (b ipv4) headerLength() uint8 {
return (b[versIHL] & 0xf) * 4
}
-// Protocol returns the value of the protocol field of the ipv4 header.
-func (b IPv4) Protocol() uint8 {
+// protocol returns the value of the protocol field of the ipv4 header.
+func (b ipv4) protocol() uint8 {
return b[protocol]
}
-// SourceAddress returns the "source address" field of the ipv4 header.
-func (b IPv4) SourceAddress() net.IP {
- return net.IP(b[srcAddr : srcAddr+IPv4AddressSize])
+// sourceAddress returns the "source address" field of the ipv4 header.
+func (b ipv4) sourceAddress() net.IP {
+ return net.IP(b[srcAddr : srcAddr+ipv4AddressSize])
}
-// DestinationAddress returns the "destination address" field of the ipv4
+// destinationAddress returns the "destination address" field of the ipv4
// header.
-func (b IPv4) DestinationAddress() net.IP {
- return net.IP(b[dstAddr : dstAddr+IPv4AddressSize])
-}
-
-// TransportProtocol implements Network.TransportProtocol.
-func (b IPv4) TransportProtocol() TransportProtocolNumber {
- return TransportProtocolNumber(b.Protocol())
+func (b ipv4) destinationAddress() net.IP {
+ return net.IP(b[dstAddr : dstAddr+ipv4AddressSize])
}
-// Payload implements Network.Payload.
-func (b IPv4) Payload() []byte {
- return b[b.HeaderLength():][:b.PayloadLength()]
+// transportProtocol implements Network.transportProtocol.
+func (b ipv4) transportProtocol() transportProtocolNumber {
+ return transportProtocolNumber(b.protocol())
}
-// PayloadLength returns the length of the payload portion of the ipv4 packet.
-func (b IPv4) PayloadLength() uint16 {
- return b.TotalLength() - uint16(b.HeaderLength())
+// payloadLength returns the length of the payload portion of the ipv4 packet.
+func (b ipv4) payloadLength() uint16 {
+ return b.totalLength() - uint16(b.headerLength())
}
-// TotalLength returns the "total length" field of the ipv4 header.
-func (b IPv4) TotalLength() uint16 {
+// totalLength returns the "total length" field of the ipv4 header.
+func (b ipv4) totalLength() uint16 {
return binary.BigEndian.Uint16(b[totalLen:])
}
-// SetTotalLength sets the "total length" field of the ipv4 header.
-func (b IPv4) SetTotalLength(totalLength uint16) {
+// setTotalLength sets the "total length" field of the ipv4 header.
+func (b ipv4) setTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[totalLen:], totalLength)
}
-// SetChecksum sets the checksum field of the ipv4 header.
-func (b IPv4) SetChecksum(v uint16) {
- binary.BigEndian.PutUint16(b[checksum:], v)
+// setChecksum sets the checksum field of the ipv4 header.
+func (b ipv4) setChecksum(v uint16) {
+ binary.BigEndian.PutUint16(b[checksumOff:], v)
}
-// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// setFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
// ipv4 header.
-func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+func (b ipv4) setFlagsFragmentOffset(flags uint8, offset uint16) {
v := (uint16(flags) << 13) | (offset >> 3)
binary.BigEndian.PutUint16(b[flagsFO:], v)
}
-// SetSourceAddress sets the "source address" field of the ipv4 header.
-func (b IPv4) SetSourceAddress(addr net.IP) {
- copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4())
-}
-
-// SetDestinationAddress sets the "destination address" field of the ipv4
-// header.
-func (b IPv4) SetDestinationAddress(addr net.IP) {
- copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4())
-}
-
-// CalculateChecksum calculates the checksum of the ipv4 header.
-func (b IPv4) CalculateChecksum() uint16 {
- return Checksum(b[:b.HeaderLength()], 0)
+// calculateChecksum calculates the checksum of the ipv4 header.
+func (b ipv4) calculateChecksum() uint16 {
+ return checksum(b[:b.headerLength()], 0)
}
-// Encode encodes all the fields of the ipv4 header.
-func (b IPv4) Encode(i *IPv4Fields) {
+// encode encodes all the fields of the ipv4 header.
+func (b ipv4) encode(i *ipv4Fields) {
b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
b[tos] = i.TOS
- b.SetTotalLength(i.TotalLength)
+ b.setTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
- b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+ b.setFlagsFragmentOffset(i.Flags, i.FragmentOffset)
b[ttl] = i.TTL
b[protocol] = i.Protocol
- b.SetChecksum(i.Checksum)
- copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
- copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
+ b.setChecksum(i.checksum)
+ copy(b[srcAddr:srcAddr+ipv4AddressSize], i.SrcAddr)
+ copy(b[dstAddr:dstAddr+ipv4AddressSize], i.DstAddr)
}
const (
udpSrcPort = 0
udpDstPort = 2
udpLength = 4
- udpChecksum = 6
+ udpchecksum = 6
)
-// UDPFields contains the fields of a UDP packet. It is used to describe the
+// udpFields contains the fields of a udp packet. It is used to describe the
// fields of a packet that needs to be encoded.
-type UDPFields struct {
- // SrcPort is the "source port" field of a UDP packet.
+type udpFields struct {
+ // SrcPort is the "source port" field of a udp packet.
SrcPort uint16
// DstPort is the "destination port" field of a UDP packet.
@@ -222,80 +189,60 @@ type UDPFields struct {
// Length is the "length" field of a UDP packet.
Length uint16
- // Checksum is the "checksum" field of a UDP packet.
- Checksum uint16
+ // checksum is the "checksum" field of a UDP packet.
+ checksum uint16
}
-// UDP represents a UDP header stored in a byte array.
-type UDP []byte
+// udp represents a udp header stored in a byte array.
+type udp []byte
const (
- // UDPMinimumSize is the minimum size of a valid UDP packet.
- UDPMinimumSize = 8
+ // udpMinimumSize is the minimum size of a valid udp packet.
+ udpMinimumSize = 8
- // UDPProtocolNumber is UDP's transport protocol number.
- UDPProtocolNumber TransportProtocolNumber = 17
+ // udpProtocolNumber is udp's transport protocol number.
+ udpProtocolNumber transportProtocolNumber = 17
)
-// SourcePort returns the "source port" field of the udp header.
-func (b UDP) SourcePort() uint16 {
+// sourcePort returns the "source port" field of the udp header.
+func (b udp) sourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
// DestinationPort returns the "destination port" field of the udp header.
-func (b UDP) DestinationPort() uint16 {
+func (b udp) destinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
// Length returns the "length" field of the udp header.
-func (b UDP) Length() uint16 {
+func (b udp) length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
-// SetSourcePort sets the "source port" field of the udp header.
-func (b UDP) SetSourcePort(port uint16) {
- binary.BigEndian.PutUint16(b[udpSrcPort:], port)
-}
-
-// SetDestinationPort sets the "destination port" field of the udp header.
-func (b UDP) SetDestinationPort(port uint16) {
- binary.BigEndian.PutUint16(b[udpDstPort:], port)
-}
-
-// SetChecksum sets the "checksum" field of the udp header.
-func (b UDP) SetChecksum(checksum uint16) {
- binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
-}
-
-// Payload returns the data contained in the UDP datagram.
-func (b UDP) Payload() []byte {
- return b[UDPMinimumSize:]
-}
-
-// Checksum returns the "checksum" field of the udp header.
-func (b UDP) Checksum() uint16 {
- return binary.BigEndian.Uint16(b[udpChecksum:])
+// setChecksum sets the "checksum" field of the udp header.
+func (b udp) setChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[udpchecksum:], checksum)
}
-// CalculateChecksum calculates the checksum of the udp packet, given the total
+// calculateChecksum calculates the checksum of the udp packet, given the total
// length of the packet and the checksum of the network-layer pseudo-header
// (excluding the total length) and the checksum of the payload.
-func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
+func (b udp) calculateChecksum(partialchecksum uint16, totalLen uint16) uint16 {
// Add the length portion of the checksum to the pseudo-checksum.
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
- checksum := Checksum(tmp, partialChecksum)
+ xsum := checksum(tmp, partialchecksum)
// Calculate the rest of the checksum.
- return Checksum(b[:UDPMinimumSize], checksum)
+ return checksum(b[:udpMinimumSize], xsum)
}
-// Encode encodes all the fields of the udp header.
-func (b UDP) Encode(u *UDPFields) {
+// encode encodes all the fields of the udp header.
+func (b udp) encode(u *udpFields) {
binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
- binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
+ binary.BigEndian.PutUint16(b[udpchecksum:], u.checksum)
}
func calculateChecksum(buf []byte, initial uint32) uint16 {
@@ -311,65 +258,65 @@ func calculateChecksum(buf []byte, initial uint32) uint16 {
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
}
- return ChecksumCombine(uint16(v), uint16(v>>16))
+ return checksumCombine(uint16(v), uint16(v>>16))
}
-// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array.
//
// The initial checksum must have been computed on an even number of bytes.
-func Checksum(buf []byte, initial uint16) uint16 {
+func checksum(buf []byte, initial uint16) uint16 {
return calculateChecksum(buf, uint32(initial))
}
-// ChecksumCombine combines the two uint16 to form their checksum. This is done
+// checksumCombine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
//
// Note that checksum a must have been computed on an even number of bytes.
-func ChecksumCombine(a, b uint16) uint16 {
+func checksumCombine(a, b uint16) uint16 {
v := uint32(a) + uint32(b)
return uint16(v + v>>16)
}
-// PseudoHeaderChecksum calculates the pseudo-header checksum for the
+// pseudoHeaderchecksum calculates the pseudo-header checksum for the
// given destination protocol and network address, ignoring the length
-// field. Pseudo-headers are needed by transport layers when calculating
+// field. pseudo-headers are needed by transport layers when calculating
// their own checksum.
-func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 {
- xsum := Checksum([]byte(srcAddr), 0)
- xsum = Checksum([]byte(dstAddr), xsum)
- return Checksum([]byte{0, uint8(protocol)}, xsum)
+func pseudoHeaderchecksum(protocol transportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 {
+ xsum := checksum([]byte(srcAddr), 0)
+ xsum = checksum([]byte(dstAddr), xsum)
+ return checksum([]byte{0, uint8(protocol)}, xsum)
}
func udp4pkt(packet []byte, dest *net.UDPAddr, src *net.UDPAddr) []byte {
- ipLen := IPv4MinimumSize
- udpLen := UDPMinimumSize
+ ipLen := ipv4MinimumSize
+ udpLen := udpMinimumSize
h := make([]byte, 0, ipLen+udpLen+len(packet))
hdr := uio.NewBigEndianBuffer(h)
- ipv4fields := &IPv4Fields{
- IHL: IPv4MinimumSize,
+ ipv4fields := &ipv4Fields{
+ IHL: ipv4MinimumSize,
TotalLength: uint16(ipLen + udpLen + len(packet)),
TTL: 64, // Per RFC 1700's recommendation for IP time to live
- Protocol: uint8(UDPProtocolNumber),
+ Protocol: uint8(udpProtocolNumber),
SrcAddr: src.IP.To4(),
DstAddr: dest.IP.To4(),
}
- ipv4hdr := IPv4(hdr.WriteN(ipLen))
- ipv4hdr.Encode(ipv4fields)
- ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum())
+ ipv4hdr := ipv4(hdr.WriteN(ipLen))
+ ipv4hdr.encode(ipv4fields)
+ ipv4hdr.setChecksum(^ipv4hdr.calculateChecksum())
- udphdr := UDP(hdr.WriteN(udpLen))
- udphdr.Encode(&UDPFields{
+ udphdr := udp(hdr.WriteN(udpLen))
+ udphdr.encode(&udpFields{
SrcPort: uint16(src.Port),
DstPort: uint16(dest.Port),
Length: uint16(udpLen + len(packet)),
})
- xsum := Checksum(packet, PseudoHeaderChecksum(
- ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr))
- udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length()))
+ xsum := checksum(packet, pseudoHeaderchecksum(
+ ipv4hdr.transportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr))
+ udphdr.setChecksum(^udphdr.calculateChecksum(xsum, udphdr.length()))
hdr.WriteBytes(packet)
return hdr.Data()
diff --git a/dhcpv4/option_string.go b/dhcpv4/option_string.go
index 289319b..eb0cc2b 100644
--- a/dhcpv4/option_string.go
+++ b/dhcpv4/option_string.go
@@ -77,3 +77,8 @@ func OptClassIdentifier(name string) Option {
func OptUserClass(name string) Option {
return Option{Code: OptionUserClassInformation, Value: String(name)}
}
+
+// OptMessage returns a new DHCPv4 (Error) Message option.
+func OptMessage(msg string) Option {
+ return Option{Code: OptionMessage, Value: String(msg)}
+}
diff --git a/dhcpv4/server4/server.go b/dhcpv4/server4/server.go
index c50e6a5..4e6796f 100644
--- a/dhcpv4/server4/server.go
+++ b/dhcpv4/server4/server.go
@@ -1,3 +1,55 @@
+// Package server4 is a basic, extensible DHCPv4 server.
+//
+// To use the DHCPv4 server code you have to call NewServer with two arguments:
+// - an interface to listen on,
+// - an address to listen on, and
+// - a handler function, that will be called every time a valid DHCPv4 packet is
+// received.
+//
+// The address to listen on is used to know IP address, port and optionally the
+// scope to create and UDP socket to listen on for DHCPv4 traffic.
+//
+// The handler is a function that takes as input a packet connection, that can
+// be used to reply to the client; a peer address, that identifies the client
+// sending the request, and the DHCPv4 packet itself. Just implement your
+// custom logic in the handler.
+//
+// Optionally, NewServer can receive options that will modify the server
+// object. Some options already exist, for example WithConn. If this option is
+// passed with a valid connection, the listening address argument is ignored.
+//
+// Example program:
+//
+// package main
+//
+// import (
+// "log"
+// "net"
+//
+// "github.com/insomniacslk/dhcp/dhcpv4"
+// "github.com/insomniacslk/dhcp/dhcpv4/server4"
+// )
+//
+// func handler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
+// // this function will just print the received DHCPv4 message, without replying
+// log.Print(m.Summary())
+// }
+//
+// func main() {
+// laddr := net.UDPAddr{
+// IP: net.ParseIP("127.0.0.1"),
+// Port: 67,
+// }
+// server, err := server4.NewServer("eth0", &laddr, handler)
+// if err != nil {
+// log.Fatal(err)
+// }
+//
+// // This never returns. If you want to do other stuff, dump it into a
+// // goroutine.
+// server.Serve()
+// }
+//
package server4
import (
@@ -8,59 +60,6 @@ import (
"github.com/insomniacslk/dhcp/dhcpv4"
)
-/*
- To use the DHCPv4 server code you have to call NewServer with two arguments:
- - an address to listen on, and
- - a handler function, that will be called every time a valid DHCPv4 packet is
- received.
-
- The address to listen on is used to know IP address, port and optionally the
- scope to create and UDP socket to listen on for DHCPv4 traffic.
-
- The handler is a function that takes as input a packet connection, that can be
- used to reply to the client; a peer address, that identifies the client sending
- the request, and the DHCPv4 packet itself. Just implement your custom logic in
- the handler.
-
- Optionally, NewServer can receive options that will modify the server object.
- Some options already exist, for example WithConn. If this option is passed with
- a valid connection, the listening address argument is ignored.
-
- Example program:
-
-
-package main
-
-import (
- "log"
- "net"
-
- "github.com/insomniacslk/dhcp/dhcpv4"
- "github.com/insomniacslk/dhcp/dhcpv4/server4"
-)
-
-func handler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
- // this function will just print the received DHCPv4 message, without replying
- log.Print(m.Summary())
-}
-
-func main() {
- laddr := net.UDPAddr{
- IP: net.ParseIP("127.0.0.1"),
- Port: 67,
- }
- server, err := server4.NewServer(&laddr, handler)
- if err != nil {
- log.Fatal(err)
- }
-
- // This never returns. If you want to do other stuff, dump it into a
- // goroutine.
- server.Serve()
-}
-
-*/
-
// Handler is a type that defines the handler function to be called every time a
// valid DHCPv4 message is received
type Handler func(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4)