// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package iptables_test

import (
	"bytes"
	"testing"

	"github.com/google/go-cmp/cmp"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/checker"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/arp"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
	"gvisor.dev/gvisor/pkg/tcpip/testutil"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
	"gvisor.dev/gvisor/pkg/waiter"
)

// inputIfNameMatcher is a stack.Matcher that matches packets traversing the
// Input hook whose input NIC name equals name. An empty name never matches.
type inputIfNameMatcher struct {
	name string
}

// Name implements stack.Matcher.
func (*inputIfNameMatcher) Name() string {
	return "inputIfNameMatcher"
}

// Match implements stack.Matcher.
//
// It reports a match only for the Input hook, and only when the matcher was
// configured with a non-empty name equal to inNicName. It never requests a
// hotdrop.
func (m *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) {
	if hook != stack.Input || m.name == "" {
		return false, false
	}
	return m.name == inNicName, false
}

var _ stack.Matcher = (*inputIfNameMatcher)(nil)

// Shared fixtures for the tests in this file:
//   - nicID and nicName identify the NIC created by genStackV4/genStackV6;
//     anotherNicName is used by filters as an interface name that differs
//     from nicName.
//   - linkAddr is the link address handed to the channel endpoints.
//   - srcAddrV4/dstAddrV4 and srcAddrV6/dstAddrV6 are the source and
//     destination addresses of the generated packets; the destination
//     addresses are the ones assigned to the NIC by genStackV4/genStackV6.
//   - payloadSize is the number of payload bytes following the IP header in
//     the packets built by genPacketV4/genPacketV6.
const (
	nicID          = 1
	nicName        = "nic1"
	anotherNicName = "nic2"
	linkAddr       = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
	srcAddrV4      = tcpip.Address("\x0a\x00\x00\x01")
	dstAddrV4      = tcpip.Address("\x0a\x00\x00\x02")
	srcAddrV6      = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
	dstAddrV6      = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
	payloadSize    = 20
)

// genStackV6 builds a stack with only the IPv6 network protocol and a single
// NIC (nicID, named nicName) backed by a channel endpoint, and assigns
// dstAddrV6 to that NIC. It returns the stack and the endpoint; any setup
// failure aborts the test.
func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
	t.Helper()

	stk := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
	})
	ep := channel.New(0, header.IPv6MinimumMTU, linkAddr)

	opts := stack.NICOptions{Name: nicName}
	if err := stk.CreateNICWithOptions(nicID, ep, opts); err != nil {
		t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
	}

	addr := tcpip.ProtocolAddress{
		Protocol:          header.IPv6ProtocolNumber,
		AddressWithPrefix: dstAddrV6.WithPrefix(),
	}
	if err := stk.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
	}

	return stk, ep
}

// genStackV4 builds a stack with only the IPv4 network protocol and a single
// NIC (nicID, named nicName) backed by a channel endpoint, and assigns
// dstAddrV4 to that NIC. It returns the stack and the endpoint; any setup
// failure aborts the test.
func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
	t.Helper()

	stk := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
	})
	ep := channel.New(0, header.IPv4MinimumMTU, linkAddr)

	opts := stack.NICOptions{Name: nicName}
	if err := stk.CreateNICWithOptions(nicID, ep, opts); err != nil {
		t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
	}

	addr := tcpip.ProtocolAddress{
		Protocol:          header.IPv4ProtocolNumber,
		AddressWithPrefix: dstAddrV4.WithPrefix(),
	}
	if err := stk.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
	}

	return stk, ep
}

// genPacketV6 returns a packet buffer whose data is an IPv6 header addressed
// from srcAddrV6 to dstAddrV6 followed by payloadSize bytes of payload.
func genPacketV6() *stack.PacketBuffer {
	totalLen := header.IPv6MinimumSize + payloadSize
	fields := header.IPv6Fields{
		PayloadLength:     payloadSize,
		TransportProtocol: 99,
		HopLimit:          255,
		SrcAddr:           srcAddrV6,
		DstAddr:           dstAddrV6,
	}

	prep := buffer.NewPrependable(totalLen)
	header.IPv6(prep.Prepend(totalLen)).Encode(&fields)

	return stack.NewPacketBuffer(stack.PacketBufferOptions{
		Data: prep.View().ToVectorisedView(),
	})
}

// genPacketV4 returns a packet buffer whose data is an IPv4 header addressed
// from srcAddrV4 to dstAddrV4 followed by payloadSize bytes of payload. The
// header checksum is computed over the encoded header.
func genPacketV4() *stack.PacketBuffer {
	totalLen := header.IPv4MinimumSize + payloadSize
	fields := header.IPv4Fields{
		TOS:            0,
		TotalLength:    uint16(totalLen),
		ID:             1,
		Flags:          0,
		FragmentOffset: 16,
		TTL:            48,
		Protocol:       99,
		SrcAddr:        srcAddrV4,
		DstAddr:        dstAddrV4,
	}

	prep := buffer.NewPrependable(totalLen)
	ipHdr := header.IPv4(prep.Prepend(totalLen))
	ipHdr.Encode(&fields)
	// Zero the checksum field before computing the checksum over the header.
	ipHdr.SetChecksum(0)
	ipHdr.SetChecksum(^ipHdr.CalculateChecksum())

	return stack.NewPacketBuffer(stack.PacketBufferOptions{
		Data: prep.View().ToVectorisedView(),
	})
}

// TestIPTablesStatsForInput injects a single inbound packet into a freshly
// built stack, optionally after rewriting the filter table's Input chain, and
// checks the stack-wide PacketsReceived and IPTablesInputDropped counters.
func TestIPTablesStatsForInput(t *testing.T) {
	// noFilter leaves the default tables untouched.
	noFilter := func(*testing.T, *stack.Stack) { /* no filter */ }

	// withFilter returns a rule customization that installs f as the rule's
	// IP header filter.
	withFilter := func(f stack.IPHeaderFilter) func(*stack.Rule) {
		return func(r *stack.Rule) { r.Filter = f }
	}

	// withIfMatcher returns a rule customization that installs a single
	// inputIfNameMatcher for ifName as the rule's matchers.
	withIfMatcher := func(ifName string) func(*stack.Rule) {
		return func(r *stack.Rule) {
			r.Matchers = []stack.Matcher{&inputIfNameMatcher{ifName}}
		}
	}

	// dropOnInput returns a filter setup function that applies the given
	// customizations to the first rule of the filter table's Input chain,
	// makes that rule a drop rule, makes the rule right after it an accept
	// rule, and installs the resulting table.
	dropOnInput := func(isIPv6 bool, customize ...func(*stack.Rule)) func(*testing.T, *stack.Stack) {
		return func(t *testing.T, s *stack.Stack) {
			t.Helper()
			ipt := s.IPTables()
			filter := ipt.GetTable(stack.FilterID, isIPv6)
			ruleIdx := filter.BuiltinChains[stack.Input]
			rule := &filter.Rules[ruleIdx]
			for _, apply := range customize {
				apply(rule)
			}
			rule.Target = &stack.DropTarget{}
			// Make sure the packet is not dropped by the next rule.
			filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
			if err := ipt.ReplaceTable(stack.FilterID, filter, isIPv6); err != nil {
				t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, isIPv6, err)
			}
		}
	}

	tests := []struct {
		name               string
		setupStack         func(*testing.T) (*stack.Stack, *channel.Endpoint)
		setupFilter        func(*testing.T, *stack.Stack)
		genPacket          func() *stack.PacketBuffer
		proto              tcpip.NetworkProtocolNumber
		expectReceived     int
		expectInputDropped int
	}{
		{
			name:               "IPv6 Accept",
			setupStack:         genStackV6,
			setupFilter:        noFilter,
			genPacket:          genPacketV6,
			proto:              header.IPv6ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 0,
		},
		{
			name:               "IPv4 Accept",
			setupStack:         genStackV4,
			setupFilter:        noFilter,
			genPacket:          genPacketV4,
			proto:              header.IPv4ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 0,
		},
		{
			name:               "IPv6 Drop (input interface matches)",
			setupStack:         genStackV6,
			setupFilter:        dropOnInput(true /* ipv6 */, withFilter(stack.IPHeaderFilter{InputInterface: nicName}), withIfMatcher(nicName)),
			genPacket:          genPacketV6,
			proto:              header.IPv6ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 1,
		},
		{
			name:               "IPv4 Drop (input interface matches)",
			setupStack:         genStackV4,
			setupFilter:        dropOnInput(false /* ipv6 */, withFilter(stack.IPHeaderFilter{InputInterface: nicName}), withIfMatcher(nicName)),
			genPacket:          genPacketV4,
			proto:              header.IPv4ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 1,
		},
		{
			name:               "IPv6 Accept (input interface does not match)",
			setupStack:         genStackV6,
			setupFilter:        dropOnInput(true /* ipv6 */, withFilter(stack.IPHeaderFilter{InputInterface: anotherNicName})),
			genPacket:          genPacketV6,
			proto:              header.IPv6ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 0,
		},
		{
			name:               "IPv4 Accept (input interface does not match)",
			setupStack:         genStackV4,
			setupFilter:        dropOnInput(false /* ipv6 */, withFilter(stack.IPHeaderFilter{InputInterface: anotherNicName})),
			genPacket:          genPacketV4,
			proto:              header.IPv4ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 0,
		},
		{
			name:       "IPv6 Drop (input interface does not match but invert is true)",
			setupStack: genStackV6,
			setupFilter: dropOnInput(true /* ipv6 */, withFilter(stack.IPHeaderFilter{
				InputInterface:       anotherNicName,
				InputInterfaceInvert: true,
			})),
			genPacket:          genPacketV6,
			proto:              header.IPv6ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 1,
		},
		{
			name:       "IPv4 Drop (input interface does not match but invert is true)",
			setupStack: genStackV4,
			setupFilter: dropOnInput(false /* ipv6 */, withFilter(stack.IPHeaderFilter{
				InputInterface:       anotherNicName,
				InputInterfaceInvert: true,
			})),
			genPacket:          genPacketV4,
			proto:              header.IPv4ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 1,
		},
		{
			name:               "IPv6 Accept (input interface does not match using a matcher)",
			setupStack:         genStackV6,
			setupFilter:        dropOnInput(true /* ipv6 */, withIfMatcher(anotherNicName)),
			genPacket:          genPacketV6,
			proto:              header.IPv6ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 0,
		},
		{
			name:               "IPv4 Accept (input interface does not match using a matcher)",
			setupStack:         genStackV4,
			setupFilter:        dropOnInput(false /* ipv6 */, withIfMatcher(anotherNicName)),
			genPacket:          genPacketV4,
			proto:              header.IPv4ProtocolNumber,
			expectReceived:     1,
			expectInputDropped: 0,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			stk, ep := tc.setupStack(t)
			tc.setupFilter(t, stk)
			ep.InjectInbound(tc.proto, tc.genPacket())

			received := int(stk.Stats().IP.PacketsReceived.Value())
			if received != tc.expectReceived {
				t.Errorf("got PacketReceived = %d, want = %d", received, tc.expectReceived)
			}

			dropped := int(stk.Stats().IP.IPTablesInputDropped.Value())
			if dropped != tc.expectInputDropped {
				t.Errorf("got IPTablesInputDropped = %d, want = %d", dropped, tc.expectInputDropped)
			}
		})
	}
}

// channelEndpointWithoutWritePacket is a channel endpoint that does not support
// stack.LinkEndpoint.WritePacket.
type channelEndpointWithoutWritePacket struct {
	*channel.Endpoint

	// t is the test that is failed if WritePacket is ever invoked.
	t *testing.T
}

var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)

// WritePacket implements stack.LinkEndpoint.
//
// It overrides the embedded endpoint's method: every call is reported as a
// test error and answered with ErrNotSupported.
func (ep *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
	ep.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
	return &tcpip.ErrNotSupported{}
}

// udpSourcePortMatcher is a stack.Matcher that matches packets whose UDP
// source port equals port.
type udpSourcePortMatcher struct {
	port uint16
}

var _ stack.Matcher = (*udpSourcePortMatcher)(nil)

// Name implements stack.Matcher.
func (*udpSourcePortMatcher) Name() string {
	return "udpSourcePortMatcher"
}

// Match implements stack.Matcher.
//
// The packet's transport header is interpreted as a UDP header. If it is
// shorter than the minimum UDP header size, the packet is not matched and a
// hotdrop is requested; otherwise the result is whether the source port
// equals the configured port.
func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
	udpHdr := header.UDP(pkt.TransportHeader().View())
	if len(udpHdr) >= header.UDPMinimumSize {
		return udpHdr.SourcePort() == m.port, false
	}

	// Drop immediately as the packet is invalid.
	return false, true
}

// TestIPTableWritePackets writes batches of UDP packets through
// Route.WritePackets on a NIC whose link endpoint rejects WritePacket, with
// and without an Output-chain rule dropping a specific source port, and checks
// the stack-wide PacketsSent and IPTablesOutputDropped counters.
func TestIPTableWritePackets(t *testing.T) {
	const (
		nicID = 1

		dropLocalPort = utils.LocalPort - 1
		acceptPackets = 2
		dropPackets   = 3
	)

	// udpHdr encodes a payload-less UDP header into hdr and fills in its
	// checksum.
	udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
		u := header.UDP(hdr)
		u.Encode(&header.UDPFields{
			SrcPort: srcPort,
			DstPort: dstPort,
			Length:  header.UDPMinimumSize,
		})
		xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
		xsum = header.Checksum(hdr, xsum)
		u.SetChecksum(^u.CalculateChecksum(xsum))
	}

	// pushUDPPackets pushes count header-only UDP packets sent from srcPort
	// to utils.RemotePort onto the front of pkts.
	pushUDPPackets := func(pkts *stack.PacketBufferList, r *stack.Route, srcPort uint16, count int) {
		for i := 0; i < count; i++ {
			pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
				ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
			})
			hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
			udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), srcPort, utils.RemotePort)
			pkts.PushFront(pkt)
		}
	}

	// singleAccepted generates one packet sent from utils.LocalPort.
	singleAccepted := func(r *stack.Route) stack.PacketBufferList {
		var pkts stack.PacketBufferList
		pushUDPPackets(&pkts, r, utils.LocalPort, 1)
		return pkts
	}

	// acceptedAndDropped generates acceptPackets packets sent from
	// utils.LocalPort followed by dropPackets packets sent from dropLocalPort.
	acceptedAndDropped := func(r *stack.Route) stack.PacketBufferList {
		var pkts stack.PacketBufferList
		pushUDPPackets(&pkts, r, utils.LocalPort, acceptPackets)
		pushUDPPackets(&pkts, r, dropLocalPort, dropPackets)
		return pkts
	}

	// dropOtherPortTable builds a filter table for netProto whose Output chain
	// starts with a rule dropping packets sent from dropLocalPort.
	dropOtherPortTable := func(netProto tcpip.NetworkProtocolNumber) stack.Table {
		accept := func() stack.Rule {
			return stack.Rule{Target: &stack.AcceptTarget{NetworkProtocol: netProto}}
		}
		// The same rule indices serve as chain entry points and underflows.
		hookRules := [stack.NumHooks]int{
			stack.Prerouting:  stack.HookUnset,
			stack.Input:       0,
			stack.Forward:     1,
			stack.Output:      2,
			stack.Postrouting: stack.HookUnset,
		}
		return stack.Table{
			Rules: []stack.Rule{
				accept(),
				accept(),
				{
					Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
					Target:   &stack.DropTarget{NetworkProtocol: netProto},
				},
				accept(),
				{
					Target: &stack.ErrorTarget{NetworkProtocol: netProto},
				},
			},
			BuiltinChains: hookRules,
			Underflows:    hookRules,
		}
	}

	noFilter := func(*testing.T, *stack.Stack) { /* no filter */ }

	dropOtherPortV4 := func(t *testing.T, s *stack.Stack) {
		t.Helper()

		table := dropOtherPortTable(header.IPv4ProtocolNumber)
		if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv6 */); err != nil {
			t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
		}
	}

	dropOtherPortV6 := func(t *testing.T, s *stack.Stack) {
		t.Helper()

		table := dropOtherPortTable(header.IPv6ProtocolNumber)
		if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
			t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
		}
	}

	tests := []struct {
		name                string
		setupFilter         func(*testing.T, *stack.Stack)
		genPacket           func(*stack.Route) stack.PacketBufferList
		proto               tcpip.NetworkProtocolNumber
		remoteAddr          tcpip.Address
		expectSent          uint64
		expectOutputDropped uint64
	}{
		{
			name:                "IPv4 Accept",
			setupFilter:         noFilter,
			genPacket:           singleAccepted,
			proto:               header.IPv4ProtocolNumber,
			remoteAddr:          dstAddrV4,
			expectSent:          1,
			expectOutputDropped: 0,
		},
		{
			name:                "IPv4 Drop Other Port",
			setupFilter:         dropOtherPortV4,
			genPacket:           acceptedAndDropped,
			proto:               header.IPv4ProtocolNumber,
			remoteAddr:          dstAddrV4,
			expectSent:          acceptPackets,
			expectOutputDropped: dropPackets,
		},
		{
			name:                "IPv6 Accept",
			setupFilter:         noFilter,
			genPacket:           singleAccepted,
			proto:               header.IPv6ProtocolNumber,
			remoteAddr:          dstAddrV6,
			expectSent:          1,
			expectOutputDropped: 0,
		},
		{
			name:                "IPv6 Drop Other Port",
			setupFilter:         dropOtherPortV6,
			genPacket:           acceptedAndDropped,
			proto:               header.IPv6ProtocolNumber,
			remoteAddr:          dstAddrV6,
			expectSent:          acceptPackets,
			expectOutputDropped: dropPackets,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			s := stack.New(stack.Options{
				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
			})
			linkEP := channelEndpointWithoutWritePacket{
				Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
				t:        t,
			}
			if err := s.CreateNIC(nicID, &linkEP); err != nil {
				t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
			}

			// Assign the IPv6 source address first, then the IPv4 one.
			localAddrs := []tcpip.ProtocolAddress{
				{
					Protocol:          header.IPv6ProtocolNumber,
					AddressWithPrefix: srcAddrV6.WithPrefix(),
				},
				{
					Protocol:          header.IPv4ProtocolNumber,
					AddressWithPrefix: srcAddrV4.WithPrefix(),
				},
			}
			for _, addr := range localAddrs {
				if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
					t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
				}
			}

			s.SetRouteTable([]tcpip.Route{
				{
					Destination: header.IPv4EmptySubnet,
					NIC:         nicID,
				},
				{
					Destination: header.IPv6EmptySubnet,
					NIC:         nicID,
				},
			})

			tc.setupFilter(t, s)

			r, err := s.FindRoute(nicID, "", tc.remoteAddr, tc.proto, false)
			if err != nil {
				t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, tc.remoteAddr, tc.proto, err)
			}
			defer r.Release()

			pkts := tc.genPacket(r)
			want := pkts.Len()
			n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
				Protocol: header.UDPProtocolNumber,
				TTL:      64,
			})
			if err != nil {
				t.Fatalf("WritePackets(...): %s", err)
			}
			if n != want {
				t.Fatalf("got WritePackets(...) = %d, want = %d", n, want)
			}

			if got := s.Stats().IP.PacketsSent.Value(); got != tc.expectSent {
				t.Errorf("got PacketSent = %d, want = %d", got, tc.expectSent)
			}
			if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != tc.expectOutputDropped {
				t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, tc.expectOutputDropped)
			}
		})
	}
}

// ttl is the TTL/hop limit passed to the utils.Rx* helpers when injecting
// packets; the forwarded-packet checkers expect ttl-1 on the forwarded copy.
const ttl = 64

// Multicast addresses parsed once at package initialization.
var (
	ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
	ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
)

// rxICMPv4EchoReply calls utils.RxICMPv4EchoReply for endpoint e with the
// given source and destination addresses and the package-level ttl.
func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
	utils.RxICMPv4EchoReply(e, src, dst, ttl)
}

// rxICMPv6EchoReply calls utils.RxICMPv6EchoReply for endpoint e with the
// given source and destination addresses and the package-level ttl.
func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
	utils.RxICMPv6EchoReply(e, src, dst, ttl)
}

// forwardedICMPv4EchoReplyChecker checks that b is an IPv4 packet from src to
// dst carrying an ICMPv4 echo reply, with a TTL one less than the injected
// ttl.
func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
	echoReply := checker.ICMPv4(checker.ICMPv4Type(header.ICMPv4EchoReply))
	checker.IPv4(
		t,
		b,
		checker.SrcAddr(src),
		checker.DstAddr(dst),
		checker.TTL(ttl-1),
		echoReply,
	)
}

// forwardedICMPv6EchoReplyChecker checks that b is an IPv6 packet from src to
// dst carrying an ICMPv6 echo reply, with a hop limit one less than the
// injected ttl.
func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
	echoReply := checker.ICMPv6(checker.ICMPv6Type(header.ICMPv6EchoReply))
	checker.IPv6(
		t,
		b,
		checker.SrcAddr(src),
		checker.DstAddr(dst),
		checker.TTL(ttl-1),
		echoReply,
	)
}

// boolToInt converts v to a counter value: 1 for true, 0 for false.
func boolToInt(v bool) uint64 {
	var n uint64
	if v {
		n = 1
	}
	return n
}

// setupDropFilter returns a filter setup function for the given hook.
//
// When invoked, the returned function fetches the filter table for netProto,
// turns the first rule of hook's builtin chain into a drop rule guarded by f,
// turns the rule right after it into an accept rule, and installs the
// modified table. A failure to install the table aborts the test.
func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
	return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
		t.Helper()

		isIPv6 := netProto == ipv6.ProtocolNumber
		ipt := s.IPTables()
		table := ipt.GetTable(stack.FilterID, isIPv6)
		first := table.BuiltinChains[hook]

		dropRule := &table.Rules[first]
		dropRule.Filter = f
		dropRule.Target = &stack.DropTarget{NetworkProtocol: netProto}

		// Make sure the packet is not dropped by the next rule.
		nextRule := &table.Rules[first+1]
		nextRule.Target = &stack.AcceptTarget{NetworkProtocol: netProto}

		if err := ipt.ReplaceTable(stack.FilterID, table, isIPv6); err != nil {
			t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, isIPv6, err)
		}
	}
}

// TestForwardingHook injects an ICMP echo reply on one NIC of a two-NIC stack
// with forwarding enabled and checks how rules installed on the filter
// table's Forward chain affect it.
//
// Each protocol/destination combination in tests is run against every filter
// configuration in subTests. The per-NIC IP statistics and the packets
// written to the second NIC's channel endpoint are then compared against the
// expected outcome.
func TestForwardingHook(t *testing.T) {
	const (
		nicID1 = 1
		nicID2 = 2

		nic1Name = "nic1"
		nic2Name = "nic2"

		// otherNICName is not the name of any NIC created by this test.
		otherNICName = "otherNIC"
	)

	// tests describes the packet to inject. local is true when dstAddr is one
	// of the addresses assigned to nicID2 below. checker validates the packet
	// read from e2; it is only invoked when a transmitted packet is expected,
	// which is never the case for local entries, so those leave it unset.
	tests := []struct {
		name             string
		netProto         tcpip.NetworkProtocolNumber
		local            bool
		srcAddr, dstAddr tcpip.Address
		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
		checker          func(*testing.T, []byte)
	}{
		{
			name:     "IPv4 remote",
			netProto: ipv4.ProtocolNumber,
			local:    false,
			srcAddr:  utils.RemoteIPv4Addr,
			dstAddr:  utils.Ipv4Addr2.AddressWithPrefix.Address,
			rx:       rxICMPv4EchoReply,
			checker: func(t *testing.T, b []byte) {
				forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
			},
		},
		{
			name:     "IPv4 local",
			netProto: ipv4.ProtocolNumber,
			local:    true,
			srcAddr:  utils.RemoteIPv4Addr,
			dstAddr:  utils.Ipv4Addr.Address,
			rx:       rxICMPv4EchoReply,
		},
		{
			name:     "IPv6 remote",
			netProto: ipv6.ProtocolNumber,
			local:    false,
			srcAddr:  utils.RemoteIPv6Addr,
			dstAddr:  utils.Ipv6Addr2.AddressWithPrefix.Address,
			rx:       rxICMPv6EchoReply,
			checker: func(t *testing.T, b []byte) {
				forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
			},
		},
		{
			name:     "IPv6 local",
			netProto: ipv6.ProtocolNumber,
			local:    true,
			srcAddr:  utils.RemoteIPv6Addr,
			dstAddr:  utils.Ipv6Addr.Address,
			rx:       rxICMPv6EchoReply,
		},
	}

	// subTests describes the rule installed on the Forward chain and whether
	// the packet is still expected to pass the chain (expectForward).
	subTests := []struct {
		name          string
		setupFilter   func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
		expectForward bool
	}{
		{
			name:          "Accept",
			setupFilter:   func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
			expectForward: true,
		},

		// Drop rules that are expected to apply to the packet.
		{
			name:          "Drop",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
			expectForward: false,
		},
		{
			name:          "Drop with input NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
			expectForward: false,
		},
		{
			name:          "Drop with output NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
			expectForward: false,
		},
		{
			name:          "Drop with input and output NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
			expectForward: false,
		},

		// Drop rules with at least one interface name that is not expected to
		// apply to the packet.
		{
			name:          "Drop with other input NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
			expectForward: true,
		},
		{
			name:          "Drop with other output NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
			expectForward: true,
		},
		{
			name:          "Drop with other input and output NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
			expectForward: true,
		},
		{
			name:          "Drop with input and other output NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
			expectForward: true,
		},
		{
			name:          "Drop with other input and other output NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
			expectForward: true,
		},

		// Drop rules naming the test's NICs but with the match inverted.
		{
			name:          "Drop with inverted input NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
			expectForward: true,
		},
		{
			name:          "Drop with inverted output NIC filtering",
			setupFilter:   setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
			expectForward: true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			for _, subTest := range subTests {
				t.Run(subTest.name, func(t *testing.T) {
					s := stack.New(stack.Options{
						NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
					})

					// The filter is installed before any NIC exists.
					subTest.setupFilter(t, s, test.netProto)

					e1 := channel.New(1, header.IPv6MinimumMTU, "")
					if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
					}

					e2 := channel.New(1, header.IPv6MinimumMTU, "")
					if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
					}

					// Both local addresses are assigned to the second NIC;
					// the packet itself is injected on the first NIC below.
					protocolAddrV4 := tcpip.ProtocolAddress{
						Protocol:          ipv4.ProtocolNumber,
						AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
					}
					if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
					}
					protocolAddrV6 := tcpip.ProtocolAddress{
						Protocol:          ipv6.ProtocolNumber,
						AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
					}
					if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
						t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
					}

					if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
					}
					if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
					}

					// Route all IPv4 and IPv6 traffic out of the second NIC.
					s.SetRouteTable([]tcpip.Route{
						{
							Destination: header.IPv4EmptySubnet,
							NIC:         nicID2,
						},
						{
							Destination: header.IPv6EmptySubnet,
							NIC:         nicID2,
						},
					})

					test.rx(e1, test.srcAddr, test.dstAddr)

					// A packet is only expected on e2 when it passes the
					// Forward chain and is not addressed to the stack itself.
					expectTransmitPacket := subTest.expectForward && !test.local

					// Statistics of the NIC the packet was injected on.
					ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
					if err != nil {
						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
					}
					ep1Stats := ep1.Stats()
					ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
					if !ok {
						t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
					}
					ip1Stats := ipEP1Stats.IPStats()

					if got := ip1Stats.PacketsReceived.Value(); got != 1 {
						t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
					}
					if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
						t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
					}
					if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
						t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
					}
					if got := ip1Stats.PacketsSent.Value(); got != 0 {
						t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
					}

					// Statistics of the NIC that owns the local addresses and
					// the routes.
					ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
					if err != nil {
						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
					}
					ep2Stats := ep2.Stats()
					ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
					if !ok {
						t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
					}
					ip2Stats := ipEP2Stats.IPStats()
					if got := ip2Stats.PacketsReceived.Value(); got != 0 {
						t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
					}
					// A valid packet is counted on the second NIC only when a
					// locally addressed packet passes the Forward chain.
					if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
						t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
					}
					if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
						t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
					}

					// Check whether a packet was written to the second NIC's
					// endpoint and, if one was expected, validate it.
					p, ok := e2.Read()
					if ok != expectTransmitPacket {
						t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket)
					}
					if expectTransmitPacket {
						test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
					}
				})
			}
		})
	}
}

// TestInputHookWithLocalForwarding exercises the Input hook when a packet
// arrives on one NIC but is destined to an address assigned to another NIC of
// the same stack, with forwarding enabled on all NICs.
//
// For each network protocol an ICMP echo request addressed to NIC2's address
// is injected on NIC1. The sub-tests install different Input drop rules and
// check, through the per-endpoint IP stats and the link endpoints, whether the
// request was dropped or answered. The expectations encode that an
// InputInterface filter is matched against the NIC the packet arrived on
// (NIC1), not the NIC that owns the destination address (NIC2).
func TestInputHookWithLocalForwarding(t *testing.T) {
	const (
		nicID1 = 1
		nicID2 = 2

		nic1Name = "nic1"
		nic2Name = "nic2"

		otherNICName = "otherNIC"
	)

	// Per-protocol behaviour: rx injects the echo request on the given link
	// endpoint, checker validates the echo reply read back from the link.
	tests := []struct {
		name     string
		netProto tcpip.NetworkProtocolNumber
		rx       func(*channel.Endpoint)
		checker  func(*testing.T, []byte)
	}{
		{
			name:     "IPv4",
			netProto: ipv4.ProtocolNumber,
			rx: func(e *channel.Endpoint) {
				utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
			},
			checker: func(t *testing.T, b []byte) {
				checker.IPv4(t, b,
					checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
					checker.DstAddr(utils.RemoteIPv4Addr),
					checker.ICMPv4(
						checker.ICMPv4Type(header.ICMPv4EchoReply)))
			},
		},
		{
			name:     "IPv6",
			netProto: ipv6.ProtocolNumber,
			rx: func(e *channel.Endpoint) {
				utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
			},
			checker: func(t *testing.T, b []byte) {
				checker.IPv6(t, b,
					checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
					checker.DstAddr(utils.RemoteIPv6Addr),
					checker.ICMPv6(
						checker.ICMPv6Type(header.ICMPv6EchoReply)))
			},
		},
	}

	// Filter configurations applied to the Input hook and whether the echo
	// request is expected to be dropped under each of them.
	subTests := []struct {
		name        string
		setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
		expectDrop  bool
	}{
		{
			name:        "Accept",
			setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
			expectDrop:  false,
		},

		{
			name:        "Drop",
			setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
			expectDrop:  true,
		},
		{
			name:        "Drop with input NIC filtering on arrival NIC",
			setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
			expectDrop:  true,
		},
		{
			name:        "Drop with input NIC filtering on delivered NIC",
			setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
			expectDrop:  false,
		},

		{
			name:        "Drop with input NIC filtering on other NIC",
			setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
			expectDrop:  false,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			for _, subTest := range subTests {
				t.Run(subTest.name, func(t *testing.T) {
					s := stack.New(stack.Options{
						NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
					})

					subTest.setupFilter(t, s, test.netProto)

					// NIC1 is the NIC the echo request arrives on.
					e1 := channel.New(1, header.IPv6MinimumMTU, "")
					if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
					}
					if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
					}
					if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
					}

					// NIC2 owns the address the echo request is destined to.
					e2 := channel.New(1, header.IPv6MinimumMTU, "")
					if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
					}
					if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
					}
					if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
						t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
					}

					if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
					}
					if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
					}

					// All outgoing traffic, including the echo reply, is routed
					// through NIC1.
					s.SetRouteTable([]tcpip.Route{
						{
							Destination: header.IPv4EmptySubnet,
							NIC:         nicID1,
						},
						{
							Destination: header.IPv6EmptySubnet,
							NIC:         nicID1,
						},
					})

					test.rx(e1)

					ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
					if err != nil {
						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
					}
					ep1Stats := ep1.Stats()
					ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
					if !ok {
						t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
					}
					ip1Stats := ipEP1Stats.IPStats()

					if got := ip1Stats.PacketsReceived.Value(); got != 1 {
						t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
					}
					if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
						t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
					}
					// The reply is only sent (out of NIC1) when the request was
					// not dropped by the Input hook.
					if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
						t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
					}

					ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
					if err != nil {
						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
					}
					ep2Stats := ep2.Stats()
					ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
					if !ok {
						t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
					}
					ip2Stats := ipEP2Stats.IPStats()
					// Nothing was received from NIC2's link, but the packet is
					// expected to be counted as valid on NIC2's endpoint, which
					// is also where an Input drop is accounted.
					if got := ip2Stats.PacketsReceived.Value(); got != 0 {
						t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
					}
					if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
						t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
					}
					if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
						t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
					}
					if got := ip2Stats.PacketsSent.Value(); got != 0 {
						t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
					}

					// A reply must appear on NIC1's link exactly when the
					// request was not dropped.
					if p, ok := e1.Read(); ok == subTest.expectDrop {
						t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
					} else if !subTest.expectDrop {
						test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
					}
					// Nothing should ever be transmitted through NIC2's link.
					if p, ok := e2.Read(); ok {
						t.Errorf("got e2.Read() = (%#v, true), want = (_, false)", p)
					}
				})
			}
		})
	}
}

// TestSNAT tests source NAT performed by a router sitting between two hosts.
//
// For every combination of network protocol (IPv4, IPv6), transport protocol
// (UDP, TCP) and NAT target (SNAT, Masquerade), a rule is installed on the
// router's NAT table Postrouting chain for packets leaving through
// utils.RouterNIC1Name. A client on host2 then connects to a server on host1
// and data is exchanged in both directions. The server is expected to observe
// the client as coming from the router's NIC1 address with the client's
// original local port, while the client is expected to observe replies as
// coming from the server's address.
func TestSNAT(t *testing.T) {
	const listenPort = 8080

	// endpointAndAddresses bundles the server/client endpoints with their
	// addresses, readable-event channels and the address the client is
	// expected to be translated to.
	type endpointAndAddresses struct {
		serverEP         tcpip.Endpoint
		serverAddr       tcpip.Address
		serverReadableCH chan struct{}

		clientEP         tcpip.Endpoint
		clientAddr       tcpip.Address
		clientReadableCH chan struct{}

		nattedClientAddr tcpip.Address
	}

	// newEP creates an endpoint on s and returns it together with a channel
	// that is signalled on readable events. The waiter entry is unregistered
	// and the endpoint closed when the test completes.
	newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
		t.Helper()
		var wq waiter.Queue
		we, ch := waiter.NewChannelEntry(nil)
		wq.EventRegister(&we, waiter.ReadableEvents)
		t.Cleanup(func() {
			wq.EventUnregister(&we)
		})

		ep, err := s.NewEndpoint(transProto, netProto, &wq)
		if err != nil {
			t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
		}
		t.Cleanup(ep.Close)

		return ep, ch
	}

	// Network protocol variants. In both, host1 runs the server and host2 runs
	// the client.
	tests := []struct {
		name       string
		netProto   tcpip.NetworkProtocolNumber
		epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
	}{
		{
			name:     "IPv4 host1 server with host2 client",
			netProto: ipv4.ProtocolNumber,
			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
				t.Helper()

				ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
				return endpointAndAddresses{
					serverEP:         ep1,
					serverAddr:       utils.Host1IPv4Addr.AddressWithPrefix.Address,
					serverReadableCH: ep1WECH,

					clientEP:         ep2,
					clientAddr:       utils.Host2IPv4Addr.AddressWithPrefix.Address,
					clientReadableCH: ep2WECH,

					nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
				}
			},
		},
		{
			name:     "IPv6 host1 server with host2 client",
			netProto: ipv6.ProtocolNumber,
			epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
				t.Helper()

				ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
				ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
				return endpointAndAddresses{
					serverEP:         ep1,
					serverAddr:       utils.Host1IPv6Addr.AddressWithPrefix.Address,
					serverReadableCH: ep1WECH,

					clientEP:         ep2,
					clientAddr:       utils.Host2IPv6Addr.AddressWithPrefix.Address,
					clientReadableCH: ep2WECH,

					nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
				}
			},
		},
	}

	// Transport protocol variants.
	//
	// setupServer (optional) prepares the bound server endpoint before the
	// client connects. setupServerConn establishes the server side of the
	// connection once the client has connected; it returns a non-nil endpoint
	// and channel when a new endpoint should be used for the data exchange.
	subTests := []struct {
		name               string
		proto              tcpip.TransportProtocolNumber
		expectedConnectErr tcpip.Error
		setupServer        func(t *testing.T, ep tcpip.Endpoint)
		setupServerConn    func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
		needRemoteAddr     bool
	}{
		{
			name:               "UDP",
			proto:              udp.ProtocolNumber,
			expectedConnectErr: nil,
			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
				t.Helper()

				// Connect the server to the (translated) client address so
				// that it can later write back without specifying a
				// destination.
				if err := ep.Connect(clientAddr); err != nil {
					t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
				}
				return nil, nil
			},
			needRemoteAddr: true,
		},
		{
			name:               "TCP",
			proto:              tcp.ProtocolNumber,
			expectedConnectErr: &tcpip.ErrConnectStarted{},
			setupServer: func(t *testing.T, ep tcpip.Endpoint) {
				t.Helper()

				if err := ep.Listen(1); err != nil {
					t.Fatalf("ep.Listen(1): %s", err)
				}
			},
			setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
				t.Helper()

				var addr tcpip.FullAddress
				for {
					newEP, wq, err := ep.Accept(&addr)
					if _, ok := err.(*tcpip.ErrWouldBlock); ok {
						// No connection is pending yet; wait for the
						// listening endpoint to become readable and retry.
						<-ch
						continue
					}
					if err != nil {
						t.Fatalf("ep.Accept(_): %s", err)
					}
					// The peer address seen by the server must be the
					// translated client address (the NIC field is ignored).
					if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
						"NIC",
					)); diff != "" {
						t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
					}

					we, newCH := waiter.NewChannelEntry(nil)
					wq.EventRegister(&we, waiter.ReadableEvents)
					return newEP, newCH
				}
			},
			needRemoteAddr: false,
		},
	}

	// setupNAT replaces the first rule of the Postrouting chain in s's NAT
	// table with a rule that applies target to packets leaving through
	// utils.RouterNIC1Name.
	setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, target stack.Target) {
		t.Helper()

		isIPv6 := netProto == ipv6.ProtocolNumber
		ipt := s.IPTables()
		natTable := ipt.GetTable(stack.NATID, isIPv6)
		ruleIdx := natTable.BuiltinChains[stack.Postrouting]
		natTable.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
		natTable.Rules[ruleIdx].Target = target
		// Make sure the packet is not dropped by the next rule.
		natTable.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
		if err := ipt.ReplaceTable(stack.NATID, natTable, isIPv6); err != nil {
			t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, isIPv6, err)
		}
	}

	// NAT target variants. Only the SNAT target is given an explicit address
	// to translate to.
	natTypes := []struct {
		name     string
		setupNAT func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber, tcpip.Address)
	}{
		{
			name: "SNAT",
			setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
				t.Helper()

				setupNAT(t, s, netProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: natToAddr})
			},
		},
		{
			name: "Masquerade",
			setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
				t.Helper()

				setupNAT(t, s, netProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			for _, subTest := range subTests {
				t.Run(subTest.name, func(t *testing.T) {
					for _, natType := range natTypes {
						t.Run(natType.name, func(t *testing.T) {
							stackOpts := stack.Options{
								NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
								TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
							}

							host1Stack := stack.New(stackOpts)
							routerStack := stack.New(stackOpts)
							host2Stack := stack.New(stackOpts)
							utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)

							epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)

							natType.setupNAT(t, routerStack, test.netProto, epsAndAddrs.nattedClientAddr)

							serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
							if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
								t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
							}
							clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
							if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
								t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
							}

							if subTest.setupServer != nil {
								subTest.setupServer(t, epsAndAddrs.serverEP)
							}
							{
								err := epsAndAddrs.clientEP.Connect(serverAddr)
								if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
									t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
								}
							}
							// The server is expected to see the client at the
							// translated address but with the client's own
							// local port.
							nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
							if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
								t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
							} else {
								nattedClientAddr.Port = addr.Port
							}

							serverEP := epsAndAddrs.serverEP
							serverCH := epsAndAddrs.serverReadableCH
							if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
								defer ep.Close()
								serverEP = ep
								serverCH = ch
							}

							// write sends all of data through ep and fails the
							// test on error or short write.
							write := func(ep tcpip.Endpoint, data []byte) {
								t.Helper()

								var r bytes.Reader
								r.Reset(data)
								var wOpts tcpip.WriteOptions
								n, err := ep.Write(&r, wOpts)
								if err != nil {
									t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
								}
								if want := int64(len(data)); n != want {
									t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
								}
							}

							// read waits on ch until ep returns data and
							// verifies that the payload equals data. The
							// sender address is compared against expectedFrom
							// only when the sub-test requests the remote
							// address.
							read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
								t.Helper()

								var buf bytes.Buffer
								var res tcpip.ReadResult
								for {
									var err tcpip.Error
									opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
									res, err = ep.Read(&buf, opts)
									if _, ok := err.(*tcpip.ErrWouldBlock); ok {
										<-ch
										continue
									}
									if err != nil {
										t.Fatalf("ep.Read(_, %#v): %s", opts, err)
									}
									break
								}

								readResult := tcpip.ReadResult{
									Count: len(data),
									Total: len(data),
								}
								if subTest.needRemoteAddr {
									readResult.RemoteAddr = expectedFrom
								}
								if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
									"ControlMessages",
									"RemoteAddr.NIC",
								)); diff != "" {
									t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
								}
								// cmp.Diff(x, y) reports x as "-" and y as "+",
								// so the expected bytes go first to match the
								// "(-want +got)" label.
								if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
									t.Errorf("received data mismatch (-want +got):\n%s", diff)
								}

								if t.Failed() {
									t.FailNow()
								}
							}

							// Client -> server: the server must see the
							// translated client address.
							{
								data := []byte{1, 2, 3, 4}
								write(epsAndAddrs.clientEP, data)
								read(serverCH, serverEP, data, nattedClientAddr)
							}

							// Server -> client: the client must see the reply
							// as coming from the server address.
							{
								data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
								write(serverEP, data)
								read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
							}
						})
					}
				})
			}
		})
	}
}