diff options
-rw-r--r-- | dhcpv4/nclient4/conn_unix.go | 16 | ||||
-rw-r--r-- | dhcpv4/nclient4/ipv4.go | 54 |
2 files changed, 48 insertions, 22 deletions
diff --git a/dhcpv4/nclient4/conn_unix.go b/dhcpv4/nclient4/conn_unix.go index 239d007..1495dc2 100644 --- a/dhcpv4/nclient4/conn_unix.go +++ b/dhcpv4/nclient4/conn_unix.go @@ -99,24 +99,16 @@ func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { pkt = pkt[:n] buf := uio.NewBigEndianBuffer(pkt) - // To read the header length, access data directly. - if !buf.Has(ipv4MinimumSize) { - continue - } - ipHdr := ipv4(buf.Data()) - headerLength := ipHdr.headerLength() - if !buf.Has(int(headerLength)) { + if !ipHdr.isValid(n) { continue } - ipHdr = ipv4(buf.Consume(int(headerLength))) + ipHdr = ipv4(buf.Consume(int(ipHdr.headerLength()))) - if headerLength > protocol { - if ipHdr.transportProtocol() != udpProtocolNumber { - continue - } + if ipHdr.transportProtocol() != udpProtocolNumber { + continue } if !buf.Has(udpHdrLen) { diff --git a/dhcpv4/nclient4/ipv4.go b/dhcpv4/nclient4/ipv4.go index c221965..3a3427a 100644 --- a/dhcpv4/nclient4/ipv4.go +++ b/dhcpv4/nclient4/ipv4.go @@ -14,6 +14,7 @@ // // This file contains code taken from gVisor. +//go:build go1.12 // +build go1.12 package nclient4 @@ -26,16 +27,17 @@ import ( ) const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksumOff = 10 - srcAddr = 12 - dstAddr = 16 + versIHL = 0 + tos = 1 + totalLen = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksumOff = 10 + srcAddr = 12 + dstAddr = 16 + ipVersionShift = 4 ) // transportProtocolNumber is the number of a transport protocol. @@ -95,8 +97,40 @@ const ( // ipv4AddressSize is the size, in bytes, of an IPv4 address. ipv4AddressSize = 4 + + // IPv4Version is the version of the IPv4 protocol. + ipv4Version = 4 ) +// IPVersion returns the version of IP used in the given packet. It returns -1 +// if the packet is not large enough to contain the version field. +func ipVersion(b []byte) int { + // Length must be at least offset+length of version field. + if len(b) < versIHL+1 { + return -1 + } + return int(b[versIHL] >> ipVersionShift) +} + +// IsValid performs basic validation on the packet. +func (b ipv4) isValid(pktSize int) bool { + if len(b) < ipv4MinimumSize { + return false + } + + hlen := int(b.headerLength()) + tlen := int(b.totalLength()) + if hlen < ipv4MinimumSize || hlen > tlen || tlen > pktSize { + return false + } + + if ipVersion(b) != ipv4Version { + return false + } + + return true +} + // headerLength returns the value of the "header length" field of the ipv4 // header. func (b ipv4) headerLength() uint8 { |