// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stack

import (
	"encoding/binary"
	"math"
	"testing"
	"time"

	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

const (
	// fwdTestNetNumber is the protocol number of the fake network protocol
	// used by these tests, fwdTestNetHeaderLen is the size in bytes of its
	// header, and fwdTestNetDefaultPrefixLen is the default prefix length
	// reported by the protocol.
	fwdTestNetNumber           tcpip.NetworkProtocolNumber = math.MaxUint32
	fwdTestNetHeaderLen                                    = 12
	fwdTestNetDefaultPrefixLen                             = 8

	// fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
	// except where another value is explicitly used. It is chosen to match
	// the MTU of loopback interfaces on Linux systems.
	fwdTestNetDefaultMTU = 65536

	// Byte offsets of the one-byte fields within the fake network header.
	dstAddrOffset        = 0
	srcAddrOffset        = 1
	protocolNumberOffset = 2
)

// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fwdTestNetworkEndpoint struct {
	AddressableEndpointState

	// nic is the interface outbound packets are written to, proto is the
	// protocol that created this endpoint, and dispatcher receives inbound
	// packets for delivery to the transport layer.
	nic        NetworkInterface
	proto      *fwdTestNetworkProtocol
	dispatcher TransportDispatcher
}

// Compile-time check that fwdTestNetworkEndpoint implements NetworkEndpoint.
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)

// Enable is a no-op for this test endpoint; it always returns nil.
func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error {
	return nil
}

// Enabled always reports true: this test endpoint is never considered
// disabled.
func (*fwdTestNetworkEndpoint) Enabled() bool {
	return true
}

// Disable is a no-op for this test endpoint.
func (*fwdTestNetworkEndpoint) Disable() {}

// MTU returns the NIC's MTU reduced by the maximum header length reported by
// MaxHeaderLength.
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
	linkMTU := f.nic.MTU()
	return linkMTU - uint32(f.MaxHeaderLength())
}

// DefaultTTL returns a fixed value of 123.
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
	return 123
}

// HandlePacket hands an inbound packet to the transport dispatcher. The
// transport protocol is taken from the protocol-number byte of the packet's
// network header.
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
	netHdr := pkt.NetworkHeader().View()
	transProto := tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset])
	f.dispatcher.DeliverTransportPacket(r, transProto, pkt)
}

// MaxHeaderLength returns the size of this protocol's header plus the maximum
// header length reported by the underlying NIC.
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
	return fwdTestNetHeaderLen + f.nic.MaxHeaderLength()
}

// PseudoHeaderChecksum always returns 0, regardless of protocol and dstAddr.
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
	return 0
}

// NetworkProtocolNumber returns the protocol number of the protocol that
// created this endpoint.
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
	return f.proto.Number()
}

// WritePacket prepends this protocol's header to pkt and passes the packet to
// the NIC for transmission. Only the destination address, source address and
// transport protocol bytes of the header are written; each address
// contributes its first byte.
func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
	hdr := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
	hdr[dstAddrOffset] = r.RemoteAddress[0]
	hdr[srcAddrOffset] = r.LocalAddress[0]
	hdr[protocolNumberOffset] = byte(params.Protocol)

	// The header is in place; let the link layer take over.
	return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
}

// WritePackets implements NetworkEndpoint.WritePackets. Batched writes are not
// supported by this test endpoint, so it panics if called.
func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
	panic("not implemented")
}

// WriteHeaderIncludedPacket is not supported by this test endpoint; it always
// returns tcpip.ErrNotSupported.
func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
	return tcpip.ErrNotSupported
}

// Close cleans up the embedded AddressableEndpointState.
func (f *fwdTestNetworkEndpoint) Close() {
	f.AddressableEndpointState.Cleanup()
}

// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
	// addrCache and neigh are passed to onLinkAddressResolved. When
	// onLinkAddressResolved is set, LinkAddressRequest schedules a call to it
	// after addrResolveDelay. When onResolveStaticAddress is set,
	// ResolveStaticAddress delegates to it.
	addrCache              *linkAddrCache
	neigh                  *neighborCache
	addrResolveDelay       time.Duration
	onLinkAddressResolved  func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress)
	onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)

	mu struct {
		sync.RWMutex
		// forwarding is read and written by Forwarding and SetForwarding
		// with the mutex held.
		forwarding bool
	}
}

// Compile-time checks that fwdTestNetworkProtocol implements both
// NetworkProtocol and LinkAddressResolver.
var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)

// Number returns fwdTestNetNumber, the protocol number of this test protocol.
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
	return fwdTestNetNumber
}

// MinimumPacketSize returns the size of this protocol's header,
// fwdTestNetHeaderLen.
func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
	return fwdTestNetHeaderLen
}

// DefaultPrefixLen returns fwdTestNetDefaultPrefixLen.
func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
	return fwdTestNetDefaultPrefixLen
}

// ParseAddresses extracts the one-byte source and destination addresses from
// the header bytes in v. v must be long enough to contain both address bytes.
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
	src = tcpip.Address(v[srcAddrOffset : srcAddrOffset+1])
	dst = tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
	return src, dst
}

// Parse consumes fwdTestNetHeaderLen bytes of pkt as the network header and
// returns the transport protocol number stored in it. If the header cannot be
// consumed, it returns 0 and false for both boolean results; otherwise both
// are true.
func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
	hdr, consumed := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen)
	if consumed {
		return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
	}
	return 0, false, false
}

// NewEndpoint creates a fwdTestNetworkEndpoint bound to nic and dispatcher and
// initializes its addressable endpoint state. The LinkAddressCache and
// NUDHandler arguments are ignored.
func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
	ep := new(fwdTestNetworkEndpoint)
	ep.nic = nic
	ep.proto = f
	ep.dispatcher = dispatcher
	ep.AddressableEndpointState.Init(ep)
	return ep
}

// SetOption always returns tcpip.ErrUnknownProtocolOption: this test protocol
// has no settable options.
func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}

// Option always returns tcpip.ErrUnknownProtocolOption: this test protocol
// has no gettable options.
func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}

// Close is a no-op for this test protocol.
func (*fwdTestNetworkProtocol) Close() {}

// Wait is a no-op for this test protocol.
func (*fwdTestNetworkProtocol) Wait() {}

// LinkAddressRequest simulates sending a link address request. If
// onLinkAddressResolved is set, it is scheduled to run after addrResolveDelay
// with the protocol's caches, addr and remoteLinkAddr; otherwise nothing
// happens. It always returns nil.
func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
	if f.onLinkAddressResolved == nil {
		return nil
	}

	// The protocol's fields are read when the timer fires, not when the
	// request is made.
	time.AfterFunc(f.addrResolveDelay, func() {
		f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
	})
	return nil
}

// ResolveStaticAddress delegates to onResolveStaticAddress when it is set.
// Without a static resolver it returns an empty link address and false.
func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
	resolve := f.onResolveStaticAddress
	if resolve == nil {
		return "", false
	}
	return resolve(addr)
}

// LinkAddressProtocol returns fwdTestNetNumber, the network protocol whose
// addresses this resolver handles.
func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
	return fwdTestNetNumber
}

// Forwarding implements stack.ForwardingNetworkProtocol. It returns the
// current forwarding flag, read under the read lock.
func (f *fwdTestNetworkProtocol) Forwarding() bool {
	f.mu.RLock()
	enabled := f.mu.forwarding
	f.mu.RUnlock()
	return enabled
}

// SetForwarding implements stack.ForwardingNetworkProtocol. It stores v as
// the forwarding flag while holding the write lock.
func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
	f.mu.Lock()
	f.mu.forwarding = v
	f.mu.Unlock()
}

// fwdTestPacketInfo holds all the information about an outbound packet: the
// link addresses taken from the route it was written with (left empty for raw
// packets) and the packet buffer itself.
type fwdTestPacketInfo struct {
	RemoteLinkAddress tcpip.LinkAddress
	LocalLinkAddress  tcpip.LinkAddress
	Pkt               *PacketBuffer
}

// fwdTestLinkEndpoint is a link endpoint for tests. Instead of transmitting
// outbound packets it queues them on C, and it lets tests inject inbound
// packets into the dispatcher it was attached to.
type fwdTestLinkEndpoint struct {
	// dispatcher is set by Attach; mtu and linkAddr are the values returned
	// by MTU and LinkAddress.
	dispatcher NetworkDispatcher
	mtu        uint32
	linkAddr   tcpip.LinkAddress

	// C is where outbound packets are queued.
	C chan fwdTestPacketInfo
}

// InjectInbound injects an inbound packet with an empty remote link address.
func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
	e.InjectLinkAddr(protocol, "", pkt)
}

// InjectLinkAddr injects an inbound packet with a remote link address by
// delivering it to the attached network dispatcher.
func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
	// Injected packets are always delivered with an empty local link address.
	const local tcpip.LinkAddress = ""
	e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}

// Attach saves the stack network-layer dispatcher for use later when packets
// are injected. It also makes IsAttached report true for a non-nil dispatcher.
func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
	e.dispatcher = dispatcher
}

// IsAttached implements stack.LinkEndpoint.IsAttached. It reports whether a
// non-nil dispatcher has been stored by Attach.
func (e *fwdTestLinkEndpoint) IsAttached() bool {
	return e.dispatcher != nil
}

// MTU implements stack.LinkEndpoint.MTU. It returns the value of the mtu
// field, which the tests set when constructing the endpoint.
func (e *fwdTestLinkEndpoint) MTU() uint32 {
	return e.mtu
}

// Capabilities implements stack.LinkEndpoint.Capabilities. The only capability
// advertised is CapabilityResolutionRequired.
//
// It uses a pointer receiver like the other methods of fwdTestLinkEndpoint.
func (*fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
	return CapabilityResolutionRequired
}

// GSOMaxSize returns the maximum GSO packet size, fixed at 1<<15 bytes.
func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
	return 1 << 15
}

// MaxHeaderLength returns the maximum size of the link layer header. This
// endpoint has no link layer header, so it returns 0.
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
	return 0
}

// LinkAddress returns the link address of this endpoint, as stored in the
// linkAddr field.
func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
	return e.linkAddr
}

// WritePacket queues pkt on e.C together with the remote and local link
// addresses of the route r. gso and protocol are ignored. It always returns
// nil.
//
// It uses a pointer receiver like the other methods of fwdTestLinkEndpoint,
// which also avoids copying the endpoint on every call.
func (e *fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
	p := fwdTestPacketInfo{
		RemoteLinkAddress: r.RemoteLinkAddress,
		LocalLinkAddress:  r.LocalLinkAddress,
		Pkt:               pkt,
	}

	// Non-blocking send: if the channel cannot accept the packet right now,
	// the packet is silently dropped rather than stalling the caller.
	select {
	case e.C <- p:
	default:
	}

	return nil
}

// WritePackets stores outbound packets into the channel by calling WritePacket
// for each packet in pkts, in list order.
//
// It returns the number of packets written. If WritePacket reports an error,
// iteration stops and that error is returned together with the number of
// packets written so far, instead of being discarded.
func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
	n := 0
	for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
		if err := e.WritePacket(r, gso, protocol, pkt); err != nil {
			return n, err
		}
		n++
	}

	return n, nil
}

// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. It wraps vv in
// a new packet buffer and queues it on e.C without any link addresses. It
// always returns nil.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
	pkt := NewPacketBuffer(PacketBufferOptions{Data: vv})

	// Non-blocking send: the packet is dropped if the channel cannot accept
	// it right now.
	select {
	case e.C <- fwdTestPacketInfo{Pkt: pkt}:
	default:
	}

	return nil
}

// Wait implements stack.LinkEndpoint.Wait. It is a no-op for this test
// endpoint.
func (*fwdTestLinkEndpoint) Wait() {}

// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. It is not
// implemented by this test endpoint and panics if called.
func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
	panic("not implemented")
}

// AddHeader implements stack.LinkEndpoint.AddHeader. It is not implemented by
// this test endpoint and panics if called.
func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
	panic("not implemented")
}

// fwdTestNetFactory builds a stack that runs proto with forwarding enabled and
// attaches two NICs to it, returning their link endpoints.
//
// NIC 1 has the link address "a" and the network address 1; NIC 2 has the
// link address "b" and the network address 2. A single route sends all packets
// out of NIC 2. Depending on useNeighborCache, either the stack's link address
// cache or NIC 2's neighbor cache is stored in proto so that the resolver
// callbacks of the tests can populate it.
//
// Any setup failure aborts the test via t.Fatal.
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
	// Attribute setup failures to the calling test rather than to this helper.
	t.Helper()

	// Create a stack with the network protocol and two NICs.
	s := New(Options{
		NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }},
		UseNeighborCache: useNeighborCache,
	})

	if !useNeighborCache {
		proto.addrCache = s.linkAddrCache
	}

	// Enable forwarding.
	s.SetForwarding(proto.Number(), true)

	// NIC 1 has the link address "a", and added the network address 1.
	ep1 = &fwdTestLinkEndpoint{
		C:        make(chan fwdTestPacketInfo, 300),
		mtu:      fwdTestNetDefaultMTU,
		linkAddr: "a",
	}
	if err := s.CreateNIC(1, ep1); err != nil {
		t.Fatal("CreateNIC #1 failed:", err)
	}
	if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
		t.Fatal("AddAddress #1 failed:", err)
	}

	// NIC 2 has the link address "b", and added the network address 2.
	ep2 = &fwdTestLinkEndpoint{
		C:        make(chan fwdTestPacketInfo, 300),
		mtu:      fwdTestNetDefaultMTU,
		linkAddr: "b",
	}
	if err := s.CreateNIC(2, ep2); err != nil {
		t.Fatal("CreateNIC #2 failed:", err)
	}
	if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
		t.Fatal("AddAddress #2 failed:", err)
	}

	if useNeighborCache {
		// Control the neighbor cache for NIC 2.
		nic, ok := s.nics[2]
		if !ok {
			t.Fatal("failed to get the neighbor cache for NIC 2")
		}
		proto.neigh = nic.neigh
	}

	// Route all packets to NIC 2.
	{
		subnet, err := tcpip.NewSubnet("\x00", "\x00")
		if err != nil {
			t.Fatal(err)
		}
		s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
	}

	return ep1, ep2
}

func TestForwardingWithStaticResolver(t *testing.T) {
	tests := []struct {
		name             string
		useNeighborCache bool
	}{
		{
			name:             "linkAddrCache",
			useNeighborCache: false,
		},
		{
			name:             "neighborCache",
			useNeighborCache: true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Create a network protocol with a static resolver.
			proto := &fwdTestNetworkProtocol{
				onResolveStaticAddress:
				// The network address 3 is resolved to the link address "c".
				func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
					if addr == "\x03" {
						return "c", true
					}
					return "", false
				},
			}

			ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache)

			// Inject an inbound packet to address 3 on NIC 1, and see if it is
			// forwarded to NIC 2.
			buf := buffer.NewView(30)
			buf[dstAddrOffset] = 3
			ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
				Data: buf.ToVectorisedView(),
			}))

			var p fwdTestPacketInfo

			select {
			case p = <-ep2.C:
			default:
				t.Fatal("packet not forwarded")
			}

			// Test that the static address resolution happened correctly.
			if p.RemoteLinkAddress != "c" {
				t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
			}
			if p.LocalLinkAddress != "b" {
				t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
			}
		})
	}
}

// TestForwardingWithFakeResolver checks that a packet injected on NIC 1 is
// forwarded out of NIC 2 once the fake resolver has resolved the destination
// to the link address "c" after a delay, with and without the neighbor cache.
func TestForwardingWithFakeResolver(t *testing.T) {
	tests := []struct {
		name             string
		useNeighborCache bool
		proto            *fwdTestNetworkProtocol
	}{
		{
			name:             "linkAddrCache",
			useNeighborCache: false,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
					// Any address will be resolved to the link address "c".
					cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
				},
			},
		},
		{
			name:             "neighborCache",
			useNeighborCache: true,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
					t.Helper()
					// This callback runs on the goroutine started by
					// time.AfterFunc, where FailNow (and so t.Fatalf) must
					// not be called; record the failure and bail out.
					if len(remoteLinkAddr) != 0 {
						t.Errorf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
						return
					}
					// Any address will be resolved to the link address "c".
					neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
						Solicited: true,
						Override:  false,
						IsRouter:  false,
					})
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)

			// Inject an inbound packet to address 3 on NIC 1, and see if it is
			// forwarded to NIC 2.
			buf := buffer.NewView(30)
			buf[dstAddrOffset] = 3
			ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
				Data: buf.ToVectorisedView(),
			}))

			var p fwdTestPacketInfo

			// Resolution is delayed, so allow up to a second for the packet.
			select {
			case p = <-ep2.C:
			case <-time.After(time.Second):
				t.Fatal("packet not forwarded")
			}

			// Test that the address resolution happened correctly.
			if p.RemoteLinkAddress != "c" {
				t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
			}
			if p.LocalLinkAddress != "b" {
				t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
			}
		})
	}
}

func TestForwardingWithNoResolver(t *testing.T) {
	// Create a network protocol without a resolver.
	proto := &fwdTestNetworkProtocol{}

	// Whether or not we use the neighbor cache here does not matter since
	// neither linkAddrCache nor neighborCache will be used.
	ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */)

	// inject an inbound packet to address 3 on NIC 1, and see if it is
	// forwarded to NIC 2.
	buf := buffer.NewView(30)
	buf[dstAddrOffset] = 3
	ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
		Data: buf.ToVectorisedView(),
	}))

	select {
	case <-ep2.C:
		t.Fatal("Packet should not be forwarded")
	case <-time.After(time.Second):
	}
}

// TestForwardingWithFakeResolverPartialTimeout injects one packet whose
// destination (address 4) never resolves and one whose destination (address 3)
// resolves to "c", and checks that the first packet received on NIC 2 is the
// one addressed to 3.
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
	tests := []struct {
		name             string
		useNeighborCache bool
		proto            *fwdTestNetworkProtocol
	}{
		{
			name:             "linkAddrCache",
			useNeighborCache: false,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
					// Only packets to address 3 will be resolved to the
					// link address "c".
					if addr == "\x03" {
						cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
					}
				},
			},
		},
		{
			name:             "neighborCache",
			useNeighborCache: true,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
					t.Helper()
					// This callback runs on the goroutine started by
					// time.AfterFunc, where FailNow (and so t.Fatalf) must
					// not be called; record the failure and bail out.
					if len(remoteLinkAddr) != 0 {
						t.Errorf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
						return
					}
					// Only packets to address 3 will be resolved to the
					// link address "c".
					if addr == "\x03" {
						neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
							Solicited: true,
							Override:  false,
							IsRouter:  false,
						})
					}
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)

			// Inject an inbound packet to address 4 on NIC 1. This packet should
			// not be forwarded.
			buf := buffer.NewView(30)
			buf[dstAddrOffset] = 4
			ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
				Data: buf.ToVectorisedView(),
			}))

			// Inject an inbound packet to address 3 on NIC 1, and see if it is
			// forwarded to NIC 2.
			buf = buffer.NewView(30)
			buf[dstAddrOffset] = 3
			ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
				Data: buf.ToVectorisedView(),
			}))

			var p fwdTestPacketInfo

			select {
			case p = <-ep2.C:
			case <-time.After(time.Second):
				t.Fatal("packet not forwarded")
			}

			// The packet that made it through must be the one sent to address 3.
			if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
				t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
			}

			// Test that the address resolution happened correctly.
			if p.RemoteLinkAddress != "c" {
				t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
			}
			if p.LocalLinkAddress != "b" {
				t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
			}
		})
	}
}

// TestForwardingWithFakeResolverTwoPackets injects two packets for the same
// unresolved destination and checks that both are forwarded out of NIC 2 once
// the fake resolver has resolved the destination to "c".
func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
	tests := []struct {
		name             string
		useNeighborCache bool
		proto            *fwdTestNetworkProtocol
	}{
		{
			name:             "linkAddrCache",
			useNeighborCache: false,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
					// Any packets will be resolved to the link address "c".
					cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
				},
			},
		},
		{
			name:             "neighborCache",
			useNeighborCache: true,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
					t.Helper()
					// This callback runs on the goroutine started by
					// time.AfterFunc, where FailNow (and so t.Fatalf) must
					// not be called; record the failure and bail out.
					if len(remoteLinkAddr) != 0 {
						t.Errorf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
						return
					}
					// Any packets will be resolved to the link address "c".
					neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
						Solicited: true,
						Override:  false,
						IsRouter:  false,
					})
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)

			// Inject two inbound packets to address 3 on NIC 1.
			for i := 0; i < 2; i++ {
				buf := buffer.NewView(30)
				buf[dstAddrOffset] = 3
				ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
					Data: buf.ToVectorisedView(),
				}))
			}

			// Both packets must come out of NIC 2.
			for i := 0; i < 2; i++ {
				var p fwdTestPacketInfo

				select {
				case p = <-ep2.C:
				case <-time.After(time.Second):
					t.Fatal("packet not forwarded")
				}

				if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
					t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
				}

				// Test that the address resolution happened correctly.
				if p.RemoteLinkAddress != "c" {
					t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
				}
				if p.LocalLinkAddress != "b" {
					t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
				}
			}
		})
	}
}

// TestForwardingWithFakeResolverManyPackets injects
// maxPendingPacketsPerResolution+5 packets for the same unresolved destination
// and checks that only the last maxPendingPacketsPerResolution of them are
// forwarded, in order, once the destination has been resolved to "c". Each
// packet carries a 16-bit big-endian sequence number right after the network
// header so that the forwarded packets can be identified.
func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
	tests := []struct {
		name             string
		useNeighborCache bool
		proto            *fwdTestNetworkProtocol
	}{
		{
			name:             "linkAddrCache",
			useNeighborCache: false,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
					// Any packets will be resolved to the link address "c".
					cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
				},
			},
		},
		{
			name:             "neighborCache",
			useNeighborCache: true,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
					t.Helper()
					// This callback runs on the goroutine started by
					// time.AfterFunc, where FailNow (and so t.Fatalf) must
					// not be called; record the failure and bail out.
					if len(remoteLinkAddr) != 0 {
						t.Errorf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
						return
					}
					// Any packets will be resolved to the link address "c".
					neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
						Solicited: true,
						Override:  false,
						IsRouter:  false,
					})
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)

			for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
				// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
				buf := buffer.NewView(30)
				buf[dstAddrOffset] = 3
				// Set the packet sequence number.
				binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
				ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
					Data: buf.ToVectorisedView(),
				}))
			}

			for i := 0; i < maxPendingPacketsPerResolution; i++ {
				var p fwdTestPacketInfo

				select {
				case p = <-ep2.C:
				case <-time.After(time.Second):
					t.Fatal("packet not forwarded")
				}

				b := PayloadSince(p.Pkt.NetworkHeader())
				if b[dstAddrOffset] != 3 {
					t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
				}
				if len(b) < fwdTestNetHeaderLen+2 {
					// Report the length, not the buffer contents.
					t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", len(b))
				}
				seqNumBuf := b[fwdTestNetHeaderLen:]

				// The first 5 packets should not be forwarded so the sequence number should
				// start with 5.
				want := uint16(i + 5)
				if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
					t.Fatalf("got the packet #%d, want = #%d", n, want)
				}

				// Test that the address resolution happened correctly.
				if p.RemoteLinkAddress != "c" {
					t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
				}
				if p.LocalLinkAddress != "b" {
					t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
				}
			}
		})
	}
}

// TestForwardingWithFakeResolverManyResolutions injects
// maxPendingResolutions+5 packets, each to a different unresolved destination,
// and checks that maxPendingResolutions packets are forwarded and that none of
// them is addressed to one of the first five destinations.
func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
	tests := []struct {
		name             string
		useNeighborCache bool
		proto            *fwdTestNetworkProtocol
	}{
		{
			name:             "linkAddrCache",
			useNeighborCache: false,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
					// Any packets will be resolved to the link address "c".
					cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
				},
			},
		},
		{
			name:             "neighborCache",
			useNeighborCache: true,
			proto: &fwdTestNetworkProtocol{
				addrResolveDelay: 500 * time.Millisecond,
				onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
					t.Helper()
					// This callback runs on the goroutine started by
					// time.AfterFunc, where FailNow (and so t.Fatalf) must
					// not be called; record the failure and bail out.
					if len(remoteLinkAddr) != 0 {
						t.Errorf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
						return
					}
					// Any packets will be resolved to the link address "c".
					neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
						Solicited: true,
						Override:  false,
						IsRouter:  false,
					})
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)

			for i := 0; i < maxPendingResolutions+5; i++ {
				// Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
				// Each packet has a different destination address (3 to
				// maxPendingResolutions + 7).
				buf := buffer.NewView(30)
				buf[dstAddrOffset] = byte(3 + i)
				ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
					Data: buf.ToVectorisedView(),
				}))
			}

			for i := 0; i < maxPendingResolutions; i++ {
				var p fwdTestPacketInfo

				select {
				case p = <-ep2.C:
				case <-time.After(time.Second):
					t.Fatal("packet not forwarded")
				}

				// The first 5 packets (address 3 to 7) should not be forwarded
				// because their address resolutions are interrupted.
				if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
					t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
				}

				// Test that the address resolution happened correctly.
				if p.RemoteLinkAddress != "c" {
					t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
				}
				if p.LocalLinkAddress != "b" {
					t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
				}
			}
		})
	}
}