summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4/nclient4/ipv4.go
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv4/nclient4/ipv4.go')
-rw-r--r--dhcpv4/nclient4/ipv4.go376
1 files changed, 376 insertions, 0 deletions
diff --git a/dhcpv4/nclient4/ipv4.go b/dhcpv4/nclient4/ipv4.go
new file mode 100644
index 0000000..81ba837
--- /dev/null
+++ b/dhcpv4/nclient4/ipv4.go
@@ -0,0 +1,376 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// This file contains code taken from gVisor.
+
+// +build go1.12
+
+package nclient4
+
+import (
+ "encoding/binary"
+ "net"
+
+ "github.com/u-root/u-root/pkg/uio"
+)
+
+const (
+ versIHL = 0
+ tos = 1
+ totalLen = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
+)
+
+// TransportProtocolNumber is the number of a transport protocol.
+type TransportProtocolNumber uint32
+
+// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv4Fields struct {
+ // IHL is the "internet header length" field of an IPv4 packet.
+ IHL uint8
+
+ // TOS is the "type of service" field of an IPv4 packet.
+ TOS uint8
+
+ // TotalLength is the "total length" field of an IPv4 packet.
+ TotalLength uint16
+
+ // ID is the "identification" field of an IPv4 packet.
+ ID uint16
+
+ // Flags is the "flags" field of an IPv4 packet.
+ Flags uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv4 packet.
+ FragmentOffset uint16
+
+ // TTL is the "time to live" field of an IPv4 packet.
+ TTL uint8
+
+ // Protocol is the "protocol" field of an IPv4 packet.
+ Protocol uint8
+
+ // Checksum is the "checksum" field of an IPv4 packet.
+ Checksum uint16
+
+ // SrcAddr is the "source ip address" of an IPv4 packet.
+ SrcAddr net.IP
+
+ // DstAddr is the "destination ip address" of an IPv4 packet.
+ DstAddr net.IP
+}
+
+// IPv4 represents an ipv4 header stored in a byte array.
+// Most of the methods of IPv4 access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv4 before using other methods.
+type IPv4 []byte
+
+const (
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ IPv4MinimumSize = 20
+
+ // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
+ // that there are only 4 bits to represents the header length in 32-bit
+ // units, the header cannot exceed 15*4 = 60 bytes.
+ IPv4MaximumHeaderSize = 60
+
+ // IPv4AddressSize is the size, in bytes, of an IPv4 address.
+ IPv4AddressSize = 4
+
+ // IPv4Version is the version of the ipv4 protocol.
+ IPv4Version = 4
+)
+
+var (
+ // IPv4Broadcast is the broadcast address of the IPv4 procotol.
+ IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff}
+
+ // IPv4Any is the non-routable IPv4 "any" meta address.
+ IPv4Any = net.IP{0, 0, 0, 0}
+)
+
+// Flags that may be set in an IPv4 packet.
+const (
+ IPv4FlagMoreFragments = 1 << iota
+ IPv4FlagDontFragment
+)
+
+// HeaderLength returns the value of the "header length" field of the ipv4
+// header.
+func (b IPv4) HeaderLength() uint8 {
+ return (b[versIHL] & 0xf) * 4
+}
+
+// Protocol returns the value of the protocol field of the ipv4 header.
+func (b IPv4) Protocol() uint8 {
+ return b[protocol]
+}
+
+// SourceAddress returns the "source address" field of the ipv4 header.
+func (b IPv4) SourceAddress() net.IP {
+ return net.IP(b[srcAddr : srcAddr+IPv4AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv4
+// header.
+func (b IPv4) DestinationAddress() net.IP {
+ return net.IP(b[dstAddr : dstAddr+IPv4AddressSize])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv4) TransportProtocol() TransportProtocolNumber {
+ return TransportProtocolNumber(b.Protocol())
+}
+
+// Payload implements Network.Payload.
+func (b IPv4) Payload() []byte {
+ return b[b.HeaderLength():][:b.PayloadLength()]
+}
+
+// PayloadLength returns the length of the payload portion of the ipv4 packet.
+func (b IPv4) PayloadLength() uint16 {
+ return b.TotalLength() - uint16(b.HeaderLength())
+}
+
+// TotalLength returns the "total length" field of the ipv4 header.
+func (b IPv4) TotalLength() uint16 {
+ return binary.BigEndian.Uint16(b[totalLen:])
+}
+
+// SetTotalLength sets the "total length" field of the ipv4 header.
+func (b IPv4) SetTotalLength(totalLength uint16) {
+ binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+}
+
+// SetChecksum sets the checksum field of the ipv4 header.
+func (b IPv4) SetChecksum(v uint16) {
+ binary.BigEndian.PutUint16(b[checksum:], v)
+}
+
+// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// ipv4 header.
+func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+ v := (uint16(flags) << 13) | (offset >> 3)
+ binary.BigEndian.PutUint16(b[flagsFO:], v)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv4 header.
+func (b IPv4) SetSourceAddress(addr net.IP) {
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4())
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv4
+// header.
+func (b IPv4) SetDestinationAddress(addr net.IP) {
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4())
+}
+
+// CalculateChecksum calculates the checksum of the ipv4 header.
+func (b IPv4) CalculateChecksum() uint16 {
+ return Checksum(b[:b.HeaderLength()], 0)
+}
+
+// Encode encodes all the fields of the ipv4 header.
+func (b IPv4) Encode(i *IPv4Fields) {
+ b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b[tos] = i.TOS
+ b.SetTotalLength(i.TotalLength)
+ binary.BigEndian.PutUint16(b[id:], i.ID)
+ b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+ b[ttl] = i.TTL
+ b[protocol] = i.Protocol
+ b.SetChecksum(i.Checksum)
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
+}
+
+const (
+ udpSrcPort = 0
+ udpDstPort = 2
+ udpLength = 4
+ udpChecksum = 6
+)
+
+// UDPFields contains the fields of a UDP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type UDPFields struct {
+ // SrcPort is the "source port" field of a UDP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a UDP packet.
+ DstPort uint16
+
+ // Length is the "length" field of a UDP packet.
+ Length uint16
+
+ // Checksum is the "checksum" field of a UDP packet.
+ Checksum uint16
+}
+
+// UDP represents a UDP header stored in a byte array.
+type UDP []byte
+
+const (
+ // UDPMinimumSize is the minimum size of a valid UDP packet.
+ UDPMinimumSize = 8
+
+ // UDPProtocolNumber is UDP's transport protocol number.
+ UDPProtocolNumber TransportProtocolNumber = 17
+)
+
+// SourcePort returns the "source port" field of the udp header.
+func (b UDP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[udpSrcPort:])
+}
+
+// DestinationPort returns the "destination port" field of the udp header.
+func (b UDP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[udpDstPort:])
+}
+
+// Length returns the "length" field of the udp header.
+func (b UDP) Length() uint16 {
+ return binary.BigEndian.Uint16(b[udpLength:])
+}
+
+// SetSourcePort sets the "source port" field of the udp header.
+func (b UDP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the udp header.
+func (b UDP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpDstPort:], port)
+}
+
+// SetChecksum sets the "checksum" field of the udp header.
+func (b UDP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
+}
+
+// Payload returns the data contained in the UDP datagram.
+func (b UDP) Payload() []byte {
+ return b[UDPMinimumSize:]
+}
+
+// Checksum returns the "checksum" field of the udp header.
+func (b UDP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[udpChecksum:])
+}
+
+// CalculateChecksum calculates the checksum of the udp packet, given the total
+// length of the packet and the checksum of the network-layer pseudo-header
+// (excluding the total length) and the checksum of the payload.
+func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ checksum := Checksum(tmp, partialChecksum)
+
+ // Calculate the rest of the checksum.
+ return Checksum(b[:UDPMinimumSize], checksum)
+}
+
+// Encode encodes all the fields of the udp header.
+func (b UDP) Encode(u *UDPFields) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
+ binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
+ binary.BigEndian.PutUint16(b[udpLength:], u.Length)
+ binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
+}
+
+func calculateChecksum(buf []byte, initial uint32) uint16 {
+ v := initial
+
+ l := len(buf)
+ if l&1 != 0 {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+
+ for i := 0; i < l; i += 2 {
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16))
+}
+
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func Checksum(buf []byte, initial uint16) uint16 {
+ return calculateChecksum(buf, uint32(initial))
+}
+
+// ChecksumCombine combines the two uint16 to form their checksum. This is done
+// by adding them and the carry.
+//
+// Note that checksum a must have been computed on an even number of bytes.
+func ChecksumCombine(a, b uint16) uint16 {
+ v := uint32(a) + uint32(b)
+ return uint16(v + v>>16)
+}
+
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the
+// given destination protocol and network address, ignoring the length
+// field. Pseudo-headers are needed by transport layers when calculating
+// their own checksum.
+func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 {
+ xsum := Checksum([]byte(srcAddr), 0)
+ xsum = Checksum([]byte(dstAddr), xsum)
+ return Checksum([]byte{0, uint8(protocol)}, xsum)
+}
+
+func udp4pkt(packet []byte, dest *net.UDPAddr, src *net.UDPAddr) []byte {
+ ipLen := IPv4MinimumSize
+ udpLen := UDPMinimumSize
+
+ h := make([]byte, 0, ipLen+udpLen+len(packet))
+ hdr := uio.NewBigEndianBuffer(h)
+
+ ipv4fields := &IPv4Fields{
+ IHL: IPv4MinimumSize,
+ TotalLength: uint16(ipLen + udpLen + len(packet)),
+ TTL: 30,
+ Protocol: uint8(UDPProtocolNumber),
+ SrcAddr: src.IP.To4(),
+ DstAddr: dest.IP.To4(),
+ }
+ ipv4hdr := IPv4(hdr.WriteN(ipLen))
+ ipv4hdr.Encode(ipv4fields)
+ ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum())
+
+ udphdr := UDP(hdr.WriteN(udpLen))
+ udphdr.Encode(&UDPFields{
+ SrcPort: uint16(src.Port),
+ DstPort: uint16(dest.Port),
+ Length: uint16(udpLen + len(packet)),
+ })
+
+ xsum := Checksum(packet, PseudoHeaderChecksum(
+ ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr))
+ udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length()))
+
+ hdr.WriteBytes(packet)
+ return hdr.Data()
+}