summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/nclient4/conn_unix.go16
-rw-r--r--dhcpv4/nclient4/ipv4.go54
2 files changed, 48 insertions, 22 deletions
diff --git a/dhcpv4/nclient4/conn_unix.go b/dhcpv4/nclient4/conn_unix.go
index 239d007..1495dc2 100644
--- a/dhcpv4/nclient4/conn_unix.go
+++ b/dhcpv4/nclient4/conn_unix.go
@@ -99,24 +99,16 @@ func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
pkt = pkt[:n]
buf := uio.NewBigEndianBuffer(pkt)
- // To read the header length, access data directly.
- if !buf.Has(ipv4MinimumSize) {
- continue
- }
-
ipHdr := ipv4(buf.Data())
- headerLength := ipHdr.headerLength()
- if !buf.Has(int(headerLength)) {
+ if !ipHdr.isValid(n) {
continue
}
- ipHdr = ipv4(buf.Consume(int(headerLength)))
+ ipHdr = ipv4(buf.Consume(int(ipHdr.headerLength())))
- if headerLength > protocol {
- if ipHdr.transportProtocol() != udpProtocolNumber {
- continue
- }
+ if ipHdr.transportProtocol() != udpProtocolNumber {
+ continue
}
if !buf.Has(udpHdrLen) {
diff --git a/dhcpv4/nclient4/ipv4.go b/dhcpv4/nclient4/ipv4.go
index c221965..3a3427a 100644
--- a/dhcpv4/nclient4/ipv4.go
+++ b/dhcpv4/nclient4/ipv4.go
@@ -14,6 +14,7 @@
//
// This file contains code taken from gVisor.
+//go:build go1.12
// +build go1.12
package nclient4
@@ -26,16 +27,17 @@ import (
)
const (
- versIHL = 0
- tos = 1
- totalLen = 2
- id = 4
- flagsFO = 6
- ttl = 8
- protocol = 9
- checksumOff = 10
- srcAddr = 12
- dstAddr = 16
+ versIHL = 0
+ tos = 1
+ totalLen = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksumOff = 10
+ srcAddr = 12
+ dstAddr = 16
+ ipVersionShift = 4
)
// transportProtocolNumber is the number of a transport protocol.
@@ -95,8 +97,40 @@ const (
// ipv4AddressSize is the size, in bytes, of an IPv4 address.
ipv4AddressSize = 4
+
+ // IPv4Version is the version of the IPv4 protocol.
+ ipv4Version = 4
)
+// IPVersion returns the version of IP used in the given packet. It returns -1
+// if the packet is not large enough to contain the version field.
+func ipVersion(b []byte) int {
+ // Length must be at least offset+length of version field.
+ if len(b) < versIHL+1 {
+ return -1
+ }
+ return int(b[versIHL] >> ipVersionShift)
+}
+
+// IsValid performs basic validation on the packet.
+func (b ipv4) isValid(pktSize int) bool {
+ if len(b) < ipv4MinimumSize {
+ return false
+ }
+
+ hlen := int(b.headerLength())
+ tlen := int(b.totalLength())
+ if hlen < ipv4MinimumSize || hlen > tlen || tlen > pktSize {
+ return false
+ }
+
+ if ipVersion(b) != ipv4Version {
+ return false
+ }
+
+ return true
+}
+
// headerLength returns the value of the "header length" field of the ipv4
// header.
func (b ipv4) headerLength() uint8 {