// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package parse provides utilities to parse packets.
package parse

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// ARP moves an ARP header from the front of pkt.Data into pkt's network
// header.
//
// It reports whether header.ARPSize bytes could be consumed. Only on success
// is pkt.NetworkProtocolNumber set to header.ARPProtocolNumber; on failure
// pkt.NetworkProtocolNumber is left untouched.
func ARP(pkt *stack.PacketBuffer) bool {
	if _, consumed := pkt.NetworkHeader().Consume(header.ARPSize); !consumed {
		return false
	}
	pkt.NetworkProtocolNumber = header.ARPProtocolNumber
	return true
}

// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network
// header with the IPv4 header, including any options.
//
// On success, pkt.NetworkProtocolNumber is set to header.IPv4ProtocolNumber
// and pkt.Data is capped to the payload length implied by the header's Total
// Length field (Total Length minus the header length).
//
// Returns true if the header was successfully parsed. Returns false if
// pkt.Data is too short to hold the header, or if the Total Length field is
// smaller than the header length, which cannot describe a valid packet.
func IPv4(pkt *stack.PacketBuffer) bool {
	hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
	if !ok {
		return false
	}
	ipHdr := header.IPv4(hdr)

	// Header may have options, determine the true header length.
	headerLen := int(ipHdr.HeaderLength())
	if headerLen < header.IPv4MinimumSize {
		// TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
		// order for the packet to be valid. Figure out if we want to reject this
		// case.
		headerLen = header.IPv4MinimumSize
	}
	hdr, ok = pkt.NetworkHeader().Consume(headerLen)
	if !ok {
		return false
	}
	ipHdr = header.IPv4(hdr)

	// Total Length covers the header as well as the payload. If it is smaller
	// than the header just consumed, the packet is malformed; reject it rather
	// than handing a negative length to CapLength and reporting success.
	payloadLen := int(ipHdr.TotalLength()) - len(hdr)
	if payloadLen < 0 {
		return false
	}

	pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
	pkt.Data.CapLength(payloadLen)
	return true
}

// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network
// header with the IPv6 header and the extension headers that follow it.
//
// Results:
//   - proto is the transport protocol number identified after the extension
//     headers. It is 0 if the extension chain ended without a payload or an
//     extension header could not be parsed.
//   - fragID, fragOffset and fragMore are taken from the first Fragment
//     extension header encountered; if that header's fields are all zero, a
//     later Fragment header may supply them instead. They are zero when no
//     Fragment header is seen.
//   - ok is false only when pkt.Data is too short to hold the fixed IPv6
//     header.
//
// On success, pkt.NetworkProtocolNumber is set to header.IPv6ProtocolNumber
// and pkt.Data is capped to the upper-layer payload: the Payload Length field
// minus the extension-header bytes moved into the network header.
func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) {
	hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
	if !ok {
		return 0, 0, 0, false, false
	}
	ipHdr := header.IPv6(hdr)

	// Payload Length counts everything after the fixed 40-byte header,
	// including any extension headers.
	payloadLen := int(ipHdr.PayloadLength())

	// dataClone is a clone of pkt.Data used only to walk the extension
	// headers. After trimming and capping it consists of:
	// - Any IPv6 header bytes after the first 40 (i.e. extensions).
	// - The transport header, if present.
	// - Any other payload data.
	// Capping it to Payload Length keeps bytes that trail the IPv6 payload
	// (e.g. link-layer padding) from being parsed as, or counted towards,
	// extension headers.
	views := [8]buffer.View{}
	dataClone := pkt.Data.Clone(views[:])
	dataClone.TrimFront(header.IPv6MinimumSize)
	dataClone.CapLength(payloadLen)
	it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)

	// Iterate over the IPv6 extensions to find their length.
	var nextHdr tcpip.TransportProtocolNumber
	var extensionsSize int

traverseExtensions:
	for {
		extHdr, done, err := it.Next()
		if err != nil {
			// The extension chain could not be parsed. extensionsSize stays 0,
			// so only the fixed header is moved into the network header and no
			// transport protocol is reported.
			break
		}

		// If we exhaust the extension list, the entire packet is the IPv6 header
		// and (possibly) extensions.
		if done {
			extensionsSize = dataClone.Size()
			break
		}

		switch extHdr := extHdr.(type) {
		case header.IPv6FragmentExtHdr:
			if fragID == 0 && fragOffset == 0 && !fragMore {
				fragID = extHdr.ID()
				fragOffset = extHdr.FragmentOffset()
				fragMore = extHdr.More()
			}

		case header.IPv6RawPayloadHeader:
			// We've found the payload after any extensions.
			extensionsSize = dataClone.Size() - extHdr.Buf.Size()
			nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
			break traverseExtensions

		default:
			// Any other extension is a no-op, keep looping until we find the payload.
		}
	}

	// Put the IPv6 header with extensions in pkt.NetworkHeader().
	if _, consumed := pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize); !consumed {
		panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
	}

	// The extension headers are now part of the network header, so what
	// remains in pkt.Data is the upper-layer payload: Payload Length minus the
	// extensions. extensionsSize never exceeds dataClone.Size(), which was
	// capped to payloadLen, so this length is not negative.
	pkt.Data.CapLength(payloadLen - extensionsSize)
	pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber

	return nextHdr, fragID, fragOffset, fragMore, true
}

// UDP moves the fixed-size UDP header from the front of pkt.Data into pkt's
// transport header.
//
// pkt.TransportProtocolNumber is set to header.UDPProtocolNumber whether or
// not the header could be consumed.
//
// Returns true if header.UDPMinimumSize bytes were consumed as the header.
func UDP(pkt *stack.PacketBuffer) bool {
	transportHdr := pkt.TransportHeader()
	_, parsed := transportHdr.Consume(header.UDPMinimumSize)
	pkt.TransportProtocolNumber = header.UDPProtocolNumber
	return parsed
}

// TCP moves a TCP header from the front of pkt.Data into pkt's transport
// header.
//
// The header's Data Offset field is honored (so options are included) only
// when it is larger than header.TCPMinimumSize and no larger than pkt.Data;
// otherwise just the fixed header.TCPMinimumSize bytes are taken.
//
// If pkt.Data is shorter than header.TCPMinimumSize, false is returned and
// pkt is left untouched. Otherwise pkt.TransportProtocolNumber is set to
// header.TCPProtocolNumber and the result of consuming the header is returned.
func TCP(pkt *stack.PacketBuffer) bool {
	// Peek at the fixed portion first; it carries the full header length.
	fixed, ok := pkt.Data.PullUp(header.TCPMinimumSize)
	if !ok {
		return false
	}

	size := header.TCPMinimumSize
	dataOffset := int(header.TCP(fixed).DataOffset())
	// TODO(gvisor.dev/issue/2404): Offsets that are too small or exceed the
	// available data currently fall back to the minimum header size. Figure
	// out whether to reject this kind of packets.
	if hasOptions := dataOffset > header.TCPMinimumSize; hasOptions && dataOffset <= pkt.Data.Size() {
		size = dataOffset
	}

	_, parsed := pkt.TransportHeader().Consume(size)
	pkt.TransportProtocolNumber = header.TCPProtocolNumber
	return parsed
}