diff options
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/dhcpv4.go | 7 | ||||
-rw-r--r-- | dhcpv4/nclient4/client.go | 78 | ||||
-rw-r--r-- | dhcpv4/nclient4/conn_linux.go | 22 | ||||
-rw-r--r-- | dhcpv4/nclient4/ipv4.go | 271 | ||||
-rw-r--r-- | dhcpv4/option_string.go | 5 | ||||
-rw-r--r-- | dhcpv4/server4/server.go | 105 |
6 files changed, 243 insertions, 245 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index 4316c5e..c09e76a 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -735,6 +735,13 @@ func (d *DHCPv4) MessageType() MessageType { return m } +// Message returns the DHCPv4 (Error) Message option. +// +// The message options is described in RFC 2132, Section 9.9. +func (d *DHCPv4) Message() string { + return GetString(OptionMessage, d.Options) +} + // ParameterRequestList returns the DHCPv4 Parameter Request List. // // The parameter request list option is described by RFC 2132, Section 9.8. diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go index 83ed065..8fedf4a 100644 --- a/dhcpv4/nclient4/client.go +++ b/dhcpv4/nclient4/client.go @@ -399,15 +399,36 @@ func WithServerAddr(n *net.UDPAddr) ClientOpt { // Matcher matches DHCP packets. type Matcher func(*dhcpv4.DHCPv4) bool -// IsMessageType returns a matcher that checks for the message type. -// -// If t is MessageTypeNone, all packets are matched. -func IsMessageType(t dhcpv4.MessageType) Matcher { +// IsMessageType returns a matcher that checks for the message types. +func IsMessageType(t dhcpv4.MessageType, tt ...dhcpv4.MessageType) Matcher { return func(p *dhcpv4.DHCPv4) bool { - return p.MessageType() == t || t == dhcpv4.MessageTypeNone + if p.MessageType() == t { + return true + } + for _, mt := range tt { + if p.MessageType() == mt { + return true + } + } + return false } } +// RemoteAddr is the default DHCP server address this client sends messages to. +func (c *Client) RemoteAddr() *net.UDPAddr { + // Make a copy so the caller cannot modify the address once the client + // is running. + cop := *c.serverAddr + return &cop +} + +// InterfaceAddr returns the MAC address of the client's interface. +func (c *Client) InterfaceAddr() net.HardwareAddr { + b := make(net.HardwareAddr, len(c.ifaceHWAddr)) + copy(b, c.ifaceHWAddr) + return b +} + // DiscoverOffer sends a DHCPDiscover message and returns the first valid offer // received. func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) { @@ -416,17 +437,14 @@ func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers, dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) if err != nil { - err = fmt.Errorf("unable to create a discovery request: %w", err) - return + return nil, fmt.Errorf("unable to create a discovery request: %w", err) } offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer)) if err != nil { - err = fmt.Errorf("got an error while the discovery request: %w", err) - return + return nil, fmt.Errorf("got an error while the discovery request: %w", err) } - - return + return offer, nil } // Request completes the 4-way Discover-Offer-Request-Ack handshake. @@ -438,25 +456,47 @@ func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (lea err = fmt.Errorf("unable to receive an offer: %w", err) return } + return c.RequestFromOffer(ctx, offer, modifiers...) +} + +// ErrNak is returned if a DHCP server rejected our Request. +type ErrNak struct { + Offer *dhcpv4.DHCPv4 + Nak *dhcpv4.DHCPv4 +} + +// Error implements error.Error. +func (e *ErrNak) Error() string { + if msg := e.Nak.Message(); len(msg) > 0 { + return fmt.Sprintf("server rejected request with Nak (msg: %s)", msg) + } + return "server rejected request with Nak" +} +// RequestFromOffer sends a Request message and waits for an response. +func (c *Client) RequestFromOffer(ctx context.Context, offer *dhcpv4.DHCPv4, modifiers ...dhcpv4.Modifier) (*Lease, error) { // TODO(chrisko): should this be unicast to the server? request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers, dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) if err != nil { - err = fmt.Errorf("unable to create a request: %w", err) - return + return nil, fmt.Errorf("unable to create a request: %w", err) } - ack, err := c.SendAndRead(ctx, c.serverAddr, request, nil) + response, err := c.SendAndRead(ctx, c.serverAddr, request, IsMessageType(dhcpv4.MessageTypeAck, dhcpv4.MessageTypeNak)) if err != nil { - err = fmt.Errorf("got an error while processing the request: %w", err) - return + return nil, fmt.Errorf("got an error while processing the request: %w", err) + } + if response.MessageType() == dhcpv4.MessageTypeNak { + return nil, &ErrNak{ + Offer: offer, + Nak: response, + } } - lease = &Lease{} - lease.ACK = ack + lease := &Lease{} + lease.ACK = response lease.Offer = offer lease.CreationTime = time.Now() - return + return lease, nil } // ErrTransactionIDInUse is returned if there were an attempt to send a message diff --git a/dhcpv4/nclient4/conn_linux.go b/dhcpv4/nclient4/conn_linux.go index 1d0ec3a..9257bee 100644 --- a/dhcpv4/nclient4/conn_linux.go +++ b/dhcpv4/nclient4/conn_linux.go @@ -82,8 +82,8 @@ func udpMatch(addr *net.UDPAddr, bound *net.UDPAddr) bool { // ReadFrom reads raw IP packets and will try to match them against // upc.boundAddr. Any matching packets are returned via the given buffer. func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { - ipHdrMaxLen := IPv4MaximumHeaderSize - udpHdrLen := UDPMinimumSize + ipHdrMaxLen := ipv4MaximumHeaderSize + udpHdrLen := udpMinimumSize for { pkt := make([]byte, ipHdrMaxLen+udpHdrLen+len(b)) @@ -98,28 +98,28 @@ func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { buf := uio.NewBigEndianBuffer(pkt) // To read the header length, access data directly. - ipHdr := IPv4(buf.Data()) - ipHdr = IPv4(buf.Consume(int(ipHdr.HeaderLength()))) + ipHdr := ipv4(buf.Data()) + ipHdr = ipv4(buf.Consume(int(ipHdr.headerLength()))) - if ipHdr.TransportProtocol() != UDPProtocolNumber { + if ipHdr.transportProtocol() != udpProtocolNumber { continue } - udpHdr := UDP(buf.Consume(udpHdrLen)) + udpHdr := udp(buf.Consume(udpHdrLen)) addr := &net.UDPAddr{ - IP: ipHdr.DestinationAddress(), - Port: int(udpHdr.DestinationPort()), + IP: ipHdr.destinationAddress(), + Port: int(udpHdr.destinationPort()), } if !udpMatch(addr, upc.boundAddr) { continue } srcAddr := &net.UDPAddr{ - IP: ipHdr.SourceAddress(), - Port: int(udpHdr.SourcePort()), + IP: ipHdr.sourceAddress(), + Port: int(udpHdr.sourcePort()), } // Extra padding after end of IP packet should be ignored, // if not dhcp option parsing will fail. - dhcpLen := int(ipHdr.PayloadLength()) - udpHdrLen + dhcpLen := int(ipHdr.payloadLength()) - udpHdrLen return copy(b, buf.Consume(dhcpLen)), srcAddr, nil } } diff --git a/dhcpv4/nclient4/ipv4.go b/dhcpv4/nclient4/ipv4.go index 5733eb4..2dffd29 100644 --- a/dhcpv4/nclient4/ipv4.go +++ b/dhcpv4/nclient4/ipv4.go @@ -26,24 +26,24 @@ import ( ) const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksum = 10 - srcAddr = 12 - dstAddr = 16 + versIHL = 0 + tos = 1 + totalLen = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksumOff = 10 + srcAddr = 12 + dstAddr = 16 ) -// TransportProtocolNumber is the number of a transport protocol. -type TransportProtocolNumber uint32 +// transportProtocolNumber is the number of a transport protocol. +type transportProtocolNumber uint32 -// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the +// ipv4Fields contains the fields of an IPv4 packet. It is used to describe the // fields of a packet that needs to be encoded. -type IPv4Fields struct { +type ipv4Fields struct { // IHL is the "internet header length" field of an IPv4 packet. IHL uint8 @@ -68,8 +68,8 @@ type IPv4Fields struct { // Protocol is the "protocol" field of an IPv4 packet. Protocol uint8 - // Checksum is the "checksum" field of an IPv4 packet. - Checksum uint16 + // checksum is the "checksum" field of an IPv4 packet. + checksum uint16 // SrcAddr is the "source ip address" of an IPv4 packet. SrcAddr net.IP @@ -78,142 +78,109 @@ type IPv4Fields struct { DstAddr net.IP } -// IPv4 represents an ipv4 header stored in a byte array. +// ipv4 represents an ipv4 header stored in a byte array. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. // Always call IsValid() to validate an instance of IPv4 before using other methods. -type IPv4 []byte +type ipv4 []byte const ( - // IPv4MinimumSize is the minimum size of a valid IPv4 packet. - IPv4MinimumSize = 20 + // ipv4MinimumSize is the minimum size of a valid IPv4 packet. + ipv4MinimumSize = 20 - // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given + // ipv4MaximumHeaderSize is the maximum size of an IPv4 header. Given // that there are only 4 bits to represents the header length in 32-bit // units, the header cannot exceed 15*4 = 60 bytes. - IPv4MaximumHeaderSize = 60 + ipv4MaximumHeaderSize = 60 - // IPv4AddressSize is the size, in bytes, of an IPv4 address. - IPv4AddressSize = 4 - - // IPv4Version is the version of the ipv4 protocol. - IPv4Version = 4 -) - -var ( - // IPv4Broadcast is the broadcast address of the IPv4 protocol. - IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff} - - // IPv4Any is the non-routable IPv4 "any" meta address. - IPv4Any = net.IP{0, 0, 0, 0} + // ipv4AddressSize is the size, in bytes, of an IPv4 address. + ipv4AddressSize = 4 ) -// Flags that may be set in an IPv4 packet. -const ( - IPv4FlagMoreFragments = 1 << iota - IPv4FlagDontFragment -) - -// HeaderLength returns the value of the "header length" field of the ipv4 +// headerLength returns the value of the "header length" field of the ipv4 // header. -func (b IPv4) HeaderLength() uint8 { +func (b ipv4) headerLength() uint8 { return (b[versIHL] & 0xf) * 4 } -// Protocol returns the value of the protocol field of the ipv4 header. -func (b IPv4) Protocol() uint8 { +// protocol returns the value of the protocol field of the ipv4 header. +func (b ipv4) protocol() uint8 { return b[protocol] } -// SourceAddress returns the "source address" field of the ipv4 header. -func (b IPv4) SourceAddress() net.IP { - return net.IP(b[srcAddr : srcAddr+IPv4AddressSize]) +// sourceAddress returns the "source address" field of the ipv4 header. +func (b ipv4) sourceAddress() net.IP { + return net.IP(b[srcAddr : srcAddr+ipv4AddressSize]) } -// DestinationAddress returns the "destination address" field of the ipv4 +// destinationAddress returns the "destination address" field of the ipv4 // header. -func (b IPv4) DestinationAddress() net.IP { - return net.IP(b[dstAddr : dstAddr+IPv4AddressSize]) -} - -// TransportProtocol implements Network.TransportProtocol. -func (b IPv4) TransportProtocol() TransportProtocolNumber { - return TransportProtocolNumber(b.Protocol()) +func (b ipv4) destinationAddress() net.IP { + return net.IP(b[dstAddr : dstAddr+ipv4AddressSize]) } -// Payload implements Network.Payload. -func (b IPv4) Payload() []byte { - return b[b.HeaderLength():][:b.PayloadLength()] +// transportProtocol implements Network.transportProtocol. +func (b ipv4) transportProtocol() transportProtocolNumber { + return transportProtocolNumber(b.protocol()) } -// PayloadLength returns the length of the payload portion of the ipv4 packet. -func (b IPv4) PayloadLength() uint16 { - return b.TotalLength() - uint16(b.HeaderLength()) +// payloadLength returns the length of the payload portion of the ipv4 packet. +func (b ipv4) payloadLength() uint16 { + return b.totalLength() - uint16(b.headerLength()) } -// TotalLength returns the "total length" field of the ipv4 header. -func (b IPv4) TotalLength() uint16 { +// totalLength returns the "total length" field of the ipv4 header. +func (b ipv4) totalLength() uint16 { return binary.BigEndian.Uint16(b[totalLen:]) } -// SetTotalLength sets the "total length" field of the ipv4 header. -func (b IPv4) SetTotalLength(totalLength uint16) { +// setTotalLength sets the "total length" field of the ipv4 header. +func (b ipv4) setTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[totalLen:], totalLength) } -// SetChecksum sets the checksum field of the ipv4 header. -func (b IPv4) SetChecksum(v uint16) { - binary.BigEndian.PutUint16(b[checksum:], v) +// setChecksum sets the checksum field of the ipv4 header. +func (b ipv4) setChecksum(v uint16) { + binary.BigEndian.PutUint16(b[checksumOff:], v) } -// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the +// setFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the // ipv4 header. -func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { +func (b ipv4) setFlagsFragmentOffset(flags uint8, offset uint16) { v := (uint16(flags) << 13) | (offset >> 3) binary.BigEndian.PutUint16(b[flagsFO:], v) } -// SetSourceAddress sets the "source address" field of the ipv4 header. -func (b IPv4) SetSourceAddress(addr net.IP) { - copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4()) -} - -// SetDestinationAddress sets the "destination address" field of the ipv4 -// header. -func (b IPv4) SetDestinationAddress(addr net.IP) { - copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4()) -} - -// CalculateChecksum calculates the checksum of the ipv4 header. -func (b IPv4) CalculateChecksum() uint16 { - return Checksum(b[:b.HeaderLength()], 0) +// calculateChecksum calculates the checksum of the ipv4 header. +func (b ipv4) calculateChecksum() uint16 { + return checksum(b[:b.headerLength()], 0) } -// Encode encodes all the fields of the ipv4 header. -func (b IPv4) Encode(i *IPv4Fields) { +// encode encodes all the fields of the ipv4 header. +func (b ipv4) encode(i *ipv4Fields) { b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) b[tos] = i.TOS - b.SetTotalLength(i.TotalLength) + b.setTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) - b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset) + b.setFlagsFragmentOffset(i.Flags, i.FragmentOffset) b[ttl] = i.TTL b[protocol] = i.Protocol - b.SetChecksum(i.Checksum) - copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr) - copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) + b.setChecksum(i.checksum) + copy(b[srcAddr:srcAddr+ipv4AddressSize], i.SrcAddr) + copy(b[dstAddr:dstAddr+ipv4AddressSize], i.DstAddr) } const ( udpSrcPort = 0 udpDstPort = 2 udpLength = 4 - udpChecksum = 6 + udpchecksum = 6 ) -// UDPFields contains the fields of a UDP packet. It is used to describe the +// udpFields contains the fields of a udp packet. It is used to describe the // fields of a packet that needs to be encoded. -type UDPFields struct { - // SrcPort is the "source port" field of a UDP packet. +type udpFields struct { + // SrcPort is the "source port" field of a udp packet. SrcPort uint16 // DstPort is the "destination port" field of a UDP packet. @@ -222,80 +189,60 @@ type UDPFields struct { // Length is the "length" field of a UDP packet. Length uint16 - // Checksum is the "checksum" field of a UDP packet. - Checksum uint16 + // checksum is the "checksum" field of a UDP packet. + checksum uint16 } -// UDP represents a UDP header stored in a byte array. -type UDP []byte +// udp represents a udp header stored in a byte array. +type udp []byte const ( - // UDPMinimumSize is the minimum size of a valid UDP packet. - UDPMinimumSize = 8 + // udpMinimumSize is the minimum size of a valid udp packet. + udpMinimumSize = 8 - // UDPProtocolNumber is UDP's transport protocol number. - UDPProtocolNumber TransportProtocolNumber = 17 + // udpProtocolNumber is udp's transport protocol number. + udpProtocolNumber transportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. -func (b UDP) SourcePort() uint16 { +// sourcePort returns the "source port" field of the udp header. +func (b udp) sourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } // DestinationPort returns the "destination port" field of the udp header. -func (b UDP) DestinationPort() uint16 { +func (b udp) destinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } // Length returns the "length" field of the udp header. -func (b UDP) Length() uint16 { +func (b udp) length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } -// SetSourcePort sets the "source port" field of the udp header. -func (b UDP) SetSourcePort(port uint16) { - binary.BigEndian.PutUint16(b[udpSrcPort:], port) -} - -// SetDestinationPort sets the "destination port" field of the udp header. -func (b UDP) SetDestinationPort(port uint16) { - binary.BigEndian.PutUint16(b[udpDstPort:], port) -} - -// SetChecksum sets the "checksum" field of the udp header. -func (b UDP) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[udpChecksum:], checksum) -} - -// Payload returns the data contained in the UDP datagram. -func (b UDP) Payload() []byte { - return b[UDPMinimumSize:] -} - -// Checksum returns the "checksum" field of the udp header. -func (b UDP) Checksum() uint16 { - return binary.BigEndian.Uint16(b[udpChecksum:]) +// setChecksum sets the "checksum" field of the udp header. +func (b udp) setChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[udpchecksum:], checksum) } -// CalculateChecksum calculates the checksum of the udp packet, given the total +// calculateChecksum calculates the checksum of the udp packet, given the total // length of the packet and the checksum of the network-layer pseudo-header // (excluding the total length) and the checksum of the payload. -func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 { +func (b udp) calculateChecksum(partialchecksum uint16, totalLen uint16) uint16 { // Add the length portion of the checksum to the pseudo-checksum. tmp := make([]byte, 2) binary.BigEndian.PutUint16(tmp, totalLen) - checksum := Checksum(tmp, partialChecksum) + xsum := checksum(tmp, partialchecksum) // Calculate the rest of the checksum. - return Checksum(b[:UDPMinimumSize], checksum) + return checksum(b[:udpMinimumSize], xsum) } -// Encode encodes all the fields of the udp header. -func (b UDP) Encode(u *UDPFields) { +// encode encodes all the fields of the udp header. +func (b udp) encode(u *udpFields) { binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) binary.BigEndian.PutUint16(b[udpLength:], u.Length) - binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) + binary.BigEndian.PutUint16(b[udpchecksum:], u.checksum) } func calculateChecksum(buf []byte, initial uint32) uint16 { @@ -311,65 +258,65 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) } - return ChecksumCombine(uint16(v), uint16(v>>16)) + return checksumCombine(uint16(v), uint16(v>>16)) } -// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the +// checksum calculates the checksum (as defined in RFC 1071) of the bytes in the // given byte array. // // The initial checksum must have been computed on an even number of bytes. -func Checksum(buf []byte, initial uint16) uint16 { +func checksum(buf []byte, initial uint16) uint16 { return calculateChecksum(buf, uint32(initial)) } -// ChecksumCombine combines the two uint16 to form their checksum. This is done +// checksumCombine combines the two uint16 to form their checksum. This is done // by adding them and the carry. // // Note that checksum a must have been computed on an even number of bytes. -func ChecksumCombine(a, b uint16) uint16 { +func checksumCombine(a, b uint16) uint16 { v := uint32(a) + uint32(b) return uint16(v + v>>16) } -// PseudoHeaderChecksum calculates the pseudo-header checksum for the +// pseudoHeaderchecksum calculates the pseudo-header checksum for the // given destination protocol and network address, ignoring the length -// field. Pseudo-headers are needed by transport layers when calculating +// field. pseudo-headers are needed by transport layers when calculating // their own checksum. -func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 { - xsum := Checksum([]byte(srcAddr), 0) - xsum = Checksum([]byte(dstAddr), xsum) - return Checksum([]byte{0, uint8(protocol)}, xsum) +func pseudoHeaderchecksum(protocol transportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 { + xsum := checksum([]byte(srcAddr), 0) + xsum = checksum([]byte(dstAddr), xsum) + return checksum([]byte{0, uint8(protocol)}, xsum) } func udp4pkt(packet []byte, dest *net.UDPAddr, src *net.UDPAddr) []byte { - ipLen := IPv4MinimumSize - udpLen := UDPMinimumSize + ipLen := ipv4MinimumSize + udpLen := udpMinimumSize h := make([]byte, 0, ipLen+udpLen+len(packet)) hdr := uio.NewBigEndianBuffer(h) - ipv4fields := &IPv4Fields{ - IHL: IPv4MinimumSize, + ipv4fields := &ipv4Fields{ + IHL: ipv4MinimumSize, TotalLength: uint16(ipLen + udpLen + len(packet)), TTL: 64, // Per RFC 1700's recommendation for IP time to live - Protocol: uint8(UDPProtocolNumber), + Protocol: uint8(udpProtocolNumber), SrcAddr: src.IP.To4(), DstAddr: dest.IP.To4(), } - ipv4hdr := IPv4(hdr.WriteN(ipLen)) - ipv4hdr.Encode(ipv4fields) - ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum()) + ipv4hdr := ipv4(hdr.WriteN(ipLen)) + ipv4hdr.encode(ipv4fields) + ipv4hdr.setChecksum(^ipv4hdr.calculateChecksum()) - udphdr := UDP(hdr.WriteN(udpLen)) - udphdr.Encode(&UDPFields{ + udphdr := udp(hdr.WriteN(udpLen)) + udphdr.encode(&udpFields{ SrcPort: uint16(src.Port), DstPort: uint16(dest.Port), Length: uint16(udpLen + len(packet)), }) - xsum := Checksum(packet, PseudoHeaderChecksum( - ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr)) - udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length())) + xsum := checksum(packet, pseudoHeaderchecksum( + ipv4hdr.transportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr)) + udphdr.setChecksum(^udphdr.calculateChecksum(xsum, udphdr.length())) hdr.WriteBytes(packet) return hdr.Data() diff --git a/dhcpv4/option_string.go b/dhcpv4/option_string.go index 289319b..eb0cc2b 100644 --- a/dhcpv4/option_string.go +++ b/dhcpv4/option_string.go @@ -77,3 +77,8 @@ func OptClassIdentifier(name string) Option { func OptUserClass(name string) Option { return Option{Code: OptionUserClassInformation, Value: String(name)} } + +// OptMessage returns a new DHCPv4 (Error) Message option. +func OptMessage(msg string) Option { + return Option{Code: OptionMessage, Value: String(msg)} +} diff --git a/dhcpv4/server4/server.go b/dhcpv4/server4/server.go index c50e6a5..4e6796f 100644 --- a/dhcpv4/server4/server.go +++ b/dhcpv4/server4/server.go @@ -1,3 +1,55 @@ +// Package server4 is a basic, extensible DHCPv4 server. +// +// To use the DHCPv4 server code you have to call NewServer with two arguments: +// - an interface to listen on, +// - an address to listen on, and +// - a handler function, that will be called every time a valid DHCPv4 packet is +// received. +// +// The address to listen on is used to know IP address, port and optionally the +// scope to create and UDP socket to listen on for DHCPv4 traffic. +// +// The handler is a function that takes as input a packet connection, that can +// be used to reply to the client; a peer address, that identifies the client +// sending the request, and the DHCPv4 packet itself. Just implement your +// custom logic in the handler. +// +// Optionally, NewServer can receive options that will modify the server +// object. Some options already exist, for example WithConn. If this option is +// passed with a valid connection, the listening address argument is ignored. +// +// Example program: +// +// package main +// +// import ( +// "log" +// "net" +// +// "github.com/insomniacslk/dhcp/dhcpv4" +// "github.com/insomniacslk/dhcp/dhcpv4/server4" +// ) +// +// func handler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) { +// // this function will just print the received DHCPv4 message, without replying +// log.Print(m.Summary()) +// } +// +// func main() { +// laddr := net.UDPAddr{ +// IP: net.ParseIP("127.0.0.1"), +// Port: 67, +// } +// server, err := server4.NewServer("eth0", &laddr, handler) +// if err != nil { +// log.Fatal(err) +// } +// +// // This never returns. If you want to do other stuff, dump it into a +// // goroutine. +// server.Serve() +// } +// package server4 import ( @@ -8,59 +60,6 @@ import ( "github.com/insomniacslk/dhcp/dhcpv4" ) -/* - To use the DHCPv4 server code you have to call NewServer with two arguments: - - an address to listen on, and - - a handler function, that will be called every time a valid DHCPv4 packet is - received. - - The address to listen on is used to know IP address, port and optionally the - scope to create and UDP socket to listen on for DHCPv4 traffic. - - The handler is a function that takes as input a packet connection, that can be - used to reply to the client; a peer address, that identifies the client sending - the request, and the DHCPv4 packet itself. Just implement your custom logic in - the handler. - - Optionally, NewServer can receive options that will modify the server object. - Some options already exist, for example WithConn. If this option is passed with - a valid connection, the listening address argument is ignored. - - Example program: - - -package main - -import ( - "log" - "net" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv4/server4" -) - -func handler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) { - // this function will just print the received DHCPv4 message, without replying - log.Print(m.Summary()) -} - -func main() { - laddr := net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 67, - } - server, err := server4.NewServer(&laddr, handler) - if err != nil { - log.Fatal(err) - } - - // This never returns. If you want to do other stuff, dump it into a - // goroutine. - server.Serve() -} - -*/ - // Handler is a type that defines the handler function to be called every time a // valid DHCPv4 message is received type Handler func(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) |