// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package fdbased

import (
	"fmt"
	"math/rand"
	"reflect"
	"syscall"
	"testing"
	"time"

	"gvisor.googlesource.com/gvisor/pkg/tcpip"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
)

// packetInfo is a record of one packet handed to the test's dispatcher. It is
// built by (*context).DeliverNetworkPacket from that call's arguments and
// compared against the expected packet in TestDeliverPacket.
type packetInfo struct {
	raddr    tcpip.LinkAddress           // remote link address passed to DeliverNetworkPacket
	proto    tcpip.NetworkProtocolNumber // network protocol number passed to DeliverNetworkPacket
	contents buffer.View                 // result of ToView() on the delivered vectorised view
}

// context bundles the state of a single endpoint test. newContext also
// attaches it to the endpoint, so it receives DeliverNetworkPacket calls and
// forwards what it is given onto ch.
type context struct {
	t    *testing.T         // test that created this context
	fds  [2]int             // socket pair: fds[1] is handed to the endpoint, fds[0] is used directly by the tests
	ep   stack.LinkEndpoint // endpoint under test
	ch   chan packetInfo    // packets recorded by DeliverNetworkPacket
	done chan struct{}      // receives a value when the endpoint's ClosedFunc runs
}

// newContext creates a SOCK_SEQPACKET unix socket pair, builds an fd-based
// endpoint from opt on top of one end of it, and attaches a new context to
// that endpoint.
//
// opt is modified in place: its FD and ClosedFunc fields are overwritten. The
// test is aborted if the socket pair cannot be created.
func newContext(t *testing.T, opt *Options) *context {
	pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
	if err != nil {
		t.Fatalf("Socketpair failed: %v", err)
	}

	// One slot of buffering lets the close notification be posted even when
	// nobody is receiving from the channel yet.
	closed := make(chan struct{}, 1)

	opt.FD = pair[1]
	opt.ClosedFunc = func(*tcpip.Error) { closed <- struct{}{} }

	linkEP := stack.FindLinkEndpoint(New(opt)).(*endpoint)
	tc := &context{
		t:    t,
		fds:  pair,
		ep:   linkEP,
		ch:   make(chan packetInfo, 100),
		done: closed,
	}

	// Register the context so the endpoint hands packets to it.
	linkEP.Attach(tc)
	return tc
}

// cleanup tears the context down in a fixed order: it closes the test's end
// of the socket pair, waits for the value posted on done by the endpoint's
// ClosedFunc, and only then closes the descriptor owned by the endpoint.
// Errors from Close are ignored.
func (c *context) cleanup() {
	testFD, epFD := c.fds[0], c.fds[1]

	syscall.Close(testFD)
	// Block until ClosedFunc has fired before releasing the endpoint's fd.
	<-c.done
	syscall.Close(epFD)
}

// DeliverNetworkPacket records the delivered packet (remote link address,
// protocol number and flattened contents) and queues it on c.ch for the test
// to inspect. linkEP is not used.
func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
	pkt := packetInfo{
		raddr:    remoteLinkAddr,
		proto:    protocol,
		contents: vv.ToView(),
	}
	c.ch <- pkt
}

// TestNoEthernetProperties checks that an endpoint created without the
// EthernetHeader option reports a max header length of zero and the MTU it
// was configured with.
func TestNoEthernetProperties(t *testing.T) {
	const mtu = 1500
	c := newContext(t, &Options{MTU: mtu})
	defer c.cleanup()

	wantHdrLen := uint16(0)
	if got := c.ep.MaxHeaderLength(); got != wantHdrLen {
		t.Fatalf("MaxHeaderLength() = %v, want %v", got, wantHdrLen)
	}

	wantMTU := uint32(mtu)
	if got := c.ep.MTU(); got != wantMTU {
		t.Fatalf("MTU() = %v, want %v", got, wantMTU)
	}
}

// TestEthernetProperties checks that an endpoint created with the
// EthernetHeader option reports header.EthernetMinimumSize as its max header
// length and the MTU it was configured with.
func TestEthernetProperties(t *testing.T) {
	const mtu = 1500
	c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
	defer c.cleanup()

	wantHdrLen := uint16(header.EthernetMinimumSize)
	if got := c.ep.MaxHeaderLength(); got != wantHdrLen {
		t.Fatalf("MaxHeaderLength() = %v, want %v", got, wantHdrLen)
	}

	wantMTU := uint32(mtu)
	if got := c.ep.MTU(); got != wantMTU {
		t.Fatalf("MTU() = %v, want %v", got, wantMTU)
	}
}

// TestAddress checks, for several link addresses including the empty one,
// that the endpoint reports back the address it was configured with.
func TestAddress(t *testing.T) {
	const mtu = 1500

	for _, addr := range []tcpip.LinkAddress{"", "abc", "def"} {
		name := fmt.Sprintf("Address: %q", addr)
		t.Run(name, func(t *testing.T) {
			c := newContext(t, &Options{Address: addr, MTU: mtu})
			defer c.cleanup()

			got := c.ep.LinkAddress()
			if got != addr {
				t.Fatalf("LinkAddress() = %v, want %v", got, addr)
			}
		})
	}
}

// TestWritePacket writes packets through the endpoint and verifies the bytes
// that show up on the test's end of the socket pair.
//
// Each subtest covers one combination of ethernet framing (on/off) and
// payload length. A 100-byte random header plus a random payload is written
// with WritePacket; the frame read back from the file descriptor must carry
// the same header and payload. With ethernet framing it must also be preceded
// by an ethernet header holding the local address as source, the route's
// remote link address as destination, and the protocol number as type.
func TestWritePacket(t *testing.T) {
	const (
		mtu   = 1500
		laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
		raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
		proto = 10
	)

	lengths := []int{0, 100, 1000}
	eths := []bool{true, false}

	for _, eth := range eths {
		for _, plen := range lengths {
			t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
				c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
				defer c.cleanup()

				r := &stack.Route{
					RemoteLinkAddress: raddr,
				}

				// Build header. Room for the endpoint's own link header is
				// reserved in front of the 100 bytes used here.
				hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
				b := hdr.Prepend(100)
				for i := range b {
					b[i] = uint8(rand.Intn(256))
				}

				// Build payload and write.
				payload := make([]byte, plen)
				for i := range payload {
					payload[i] = uint8(rand.Intn(256))
				}

				// Snapshot the expected bytes into fresh storage before
				// writing. Appending directly to hdr.UsedBytes() can return a
				// slice that still aliases the header's buffer (e.g. when the
				// payload is empty), in which case a write that corrupted the
				// header would corrupt the expectation as well.
				used := hdr.UsedBytes()
				want := append(used[:0:0], used...)
				want = append(want, payload...)

				if err := c.ep.WritePacket(r, &hdr, payload, proto); err != nil {
					t.Fatalf("WritePacket failed: %v", err)
				}

				// Read from fd, then compare with what we wrote.
				b = make([]byte, mtu)
				n, err := syscall.Read(c.fds[0], b)
				if err != nil {
					t.Fatalf("Read failed: %v", err)
				}
				b = b[:n]
				if eth {
					// Fail cleanly rather than panic on the slice expression
					// below if the frame is too short to hold the header.
					if len(b) < header.EthernetMinimumSize {
						t.Fatalf("Read returned %v bytes, want at least %v for the ethernet header", len(b), header.EthernetMinimumSize)
					}

					h := header.Ethernet(b)
					b = b[header.EthernetMinimumSize:]

					if a := h.SourceAddress(); a != laddr {
						t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
					}

					if a := h.DestinationAddress(); a != raddr {
						t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
					}

					if et := h.Type(); et != proto {
						t.Fatalf("Type() = %v, want %v", et, proto)
					}
				}
				if len(b) != len(want) {
					t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
				}
				if !reflect.DeepEqual(b, want) {
					t.Fatalf("Read returned %x, want %x", b, want)
				}
			})
		}
	}
}

// TestDeliverPacket injects frames on the test's end of the socket pair and
// verifies what the endpoint hands to the attached context.
//
// With ethernet framing the frame carries an ethernet header, and the
// delivered packet is expected to report that header's source address and
// type. Without it the payload is made to look like IPv4, and the delivered
// packet is expected to carry an empty remote address and the IPv4 protocol
// number. In both cases the delivered contents must equal the payload.
func TestDeliverPacket(t *testing.T) {
	const (
		mtu   = 1500
		laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
		raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
		proto = 10
	)

	for _, eth := range []bool{true, false} {
		for _, plen := range []int{100, 1000} {
			name := fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen)
			t.Run(name, func(t *testing.T) {
				c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
				defer c.cleanup()

				// Random payload.
				payload := make([]byte, plen)
				for i := range payload {
					payload[i] = uint8(rand.Intn(256))
				}

				// frame is what goes on the wire: the bare payload, or the
				// payload behind an ethernet header.
				frame := payload
				if eth {
					hdr := make(header.Ethernet, header.EthernetMinimumSize)
					hdr.Encode(&header.EthernetFields{
						SrcAddr: raddr,
						DstAddr: laddr,
						Type:    proto,
					})
					frame = append(hdr, payload...)
				} else {
					// So that it looks like an IPv4 packet.
					payload[0] = 0x40
				}

				// Inject the frame through the file descriptor.
				if _, err := syscall.Write(c.fds[0], frame); err != nil {
					t.Fatalf("Write failed: %v", err)
				}

				// What the dispatcher should be handed.
				want := packetInfo{contents: payload}
				if eth {
					want.raddr = raddr
					want.proto = proto
				} else {
					want.raddr = ""
					want.proto = header.IPv4ProtocolNumber
				}

				// Receive packet through the endpoint.
				select {
				case got := <-c.ch:
					if !reflect.DeepEqual(want, got) {
						t.Fatalf("Unexpected received packet: %+v, want %+v", got, want)
					}
				case <-time.After(10 * time.Second):
					t.Fatalf("Timed out waiting for packet")
				}
			})
		}
	}
}

// TestBufConfigMaxLength checks that the sizes listed in BufConfig add up to
// at least header.MaxIPPacketSize.
func TestBufConfigMaxLength(t *testing.T) {
	total := 0
	for i := range BufConfig {
		total += BufConfig[i]
	}

	// The combined buffers must be able to hold a packet of this size.
	if limit := header.MaxIPPacketSize; total < limit {
		t.Errorf("total buffer size is invalid: got %d, want >= %d", total, limit)
	}
}

// TestBufConfigFirst checks that the first entry of BufConfig is at least
// 120 bytes.
func TestBufConfigFirst(t *testing.T) {
	// The stack assumes that the TCP/IP header is entirely contained in the
	// first view, so that view must be able to hold the largest TCP/IP
	// header: 120 bytes (60 bytes for IP + 60 bytes for TCP).
	const minFirst = 120

	if first := BufConfig[0]; first < minFirst {
		t.Errorf("first view has an invalid size: got %d, want >= %d", first, minFirst)
	}
}

// build returns a bare endpoint for the capViews tests. Its views and iovecs
// slices get one entry per element of bufConfig, and allocateViews is then
// called with bufConfig. No other endpoint field is set.
func build(bufConfig []int) *endpoint {
	n := len(bufConfig)

	ep := new(endpoint)
	ep.views = make([]buffer.View, n)
	ep.iovecs = make([]syscall.Iovec, n)
	ep.allocateViews(bufConfig)
	return ep
}

// capLengthTestCases is the table driving TestCapLength. For each entry an
// endpoint is built from config, capViews(n, config) is called on it, and the
// return value and the resulting length of every view are compared with
// wantUsed and wantLengths.
var capLengthTestCases = []struct {
	comment     string // label printed when the case fails
	config      []int  // buffer sizes passed to build and to capViews
	n           int    // first argument to capViews
	wantUsed    int    // expected return value of capViews
	wantLengths []int  // expected len() of each endpoint view after capViews
}{
	// n fits inside the only view, which is expected to shrink to n.
	{
		comment:     "Single slice",
		config:      []int{2},
		n:           1,
		wantUsed:    1,
		wantLengths: []int{1},
	},
	// n spills into the second view, which is expected to be shortened.
	{
		comment:     "Multiple slices",
		config:      []int{1, 2},
		n:           2,
		wantUsed:    2,
		wantLengths: []int{1, 1},
	},
	// n equals the total configured size; both views keep their full length.
	{
		comment:     "Entire buffer",
		config:      []int{1, 2},
		n:           3,
		wantUsed:    2,
		wantLengths: []int{1, 2},
	},
	// n exactly fills the first two views; the third is not counted as used
	// and is expected to keep its configured length.
	{
		comment:     "Entire buffer but not on the last slice",
		config:      []int{1, 2, 3},
		n:           3,
		wantUsed:    2,
		wantLengths: []int{1, 2, 3},
	},
}

// TestCapLength runs every entry of capLengthTestCases. It builds an endpoint
// from the entry's buffer configuration, calls capViews, and checks both the
// returned count and the resulting length of each view. Failures are reported
// with t.Errorf, so all cases run even if one fails.
func TestCapLength(t *testing.T) {
	for _, tc := range capLengthTestCases {
		ep := build(tc.config)

		if got := ep.capViews(tc.n, tc.config); got != tc.wantUsed {
			t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", tc.comment, tc.n, tc.config, got, tc.wantUsed)
		}

		// Collect the post-call length of every view for comparison.
		gotLengths := make([]int, 0, len(ep.views))
		for _, v := range ep.views {
			gotLengths = append(gotLengths, len(v))
		}
		if !reflect.DeepEqual(gotLengths, tc.wantLengths) {
			t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %v. Want %v", tc.comment, tc.n, tc.config, gotLengths, tc.wantLengths)
		}
	}
}