summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go248
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go285
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go453
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go11
5 files changed, 855 insertions, 144 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index ab2dab60c..181ef799e 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -53,12 +53,14 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 42bc53328..92fa6257d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -16,6 +16,7 @@ package forward_test
import (
"bytes"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -34,6 +35,39 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+const ttl = 64
+
+var (
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo)))
+}
+
+func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+}
+
func TestForwarding(t *testing.T) {
const listenPort = 8080
@@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) {
const (
nicID1 = 1
nicID2 = 2
- ttl = 64
)
var (
ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
)
- rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
- }
-
- rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
- }
-
- v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo)))
- }
-
- v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest)))
- }
-
tests := []struct {
name string
srcAddr, dstAddr tcpip.Address
@@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
},
},
{
@@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
},
},
@@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
},
},
{
@@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
},
},
}
@@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) {
})
}
}
+
+func TestPerInterfaceForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ )
+
+ tests := []struct {
+ name string
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 unicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4GlobalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ },
+ },
+
+ {
+ name: "IPv6 unicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6GlobalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ },
+ },
+ }
+
+ netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ // ARP is not used in this test but it is a network protocol that does
+ // not support forwarding. We install the protocol to make sure that
+ // forwarding information for a NIC is only reported for network
+ // protocols that support forwarding.
+ arp.NewProtocol,
+
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ for _, add := range [...]struct {
+ nicID tcpip.NICID
+ addr tcpip.ProtocolAddress
+ }{
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv4Addr,
+ },
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv6Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv4Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv6Addr,
+ },
+ } {
+ if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ }
+ }
+
+ // Only enable forwarding on NIC1 and make sure that only packets arriving
+ // on NIC1 are forwarded.
+ for _, netProto := range netProtos {
+ if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
+ t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
+ }
+ }
+
+ nicsInfo := s.NICInfo()
+ for _, subTest := range [...]struct {
+ nicID tcpip.NICID
+ nicEP *channel.Endpoint
+ otherNICID tcpip.NICID
+ otherNICEP *channel.Endpoint
+ expectForwarding bool
+ }{
+ {
+ nicID: nicID1,
+ nicEP: e1,
+ otherNICID: nicID2,
+ otherNICEP: e2,
+ expectForwarding: true,
+ },
+ {
+ nicID: nicID2,
+ nicEP: e2,
+ otherNICID: nicID2,
+ otherNICEP: e1,
+ expectForwarding: false,
+ },
+ } {
+ t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
+ nicInfo, ok := nicsInfo[subTest.nicID]
+ if !ok {
+ t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
+ } else {
+ forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
+ for _, netProto := range netProtos {
+ forwarding[netProto] = subTest.expectForwarding
+ }
+
+ if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
+ t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ })
+
+ test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
+ if p, ok := subTest.nicEP.Read(); ok {
+ t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
+ }
+ if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
+ t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
+ } else if subTest.expectForwarding {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 07ba2b837..f9ab7d0af 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
// Make sure the packet is not dropped by the next rule.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr
checker.ICMPv6Type(header.ICMPv6EchoReply)))
}
+func boolToInt(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
+func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, ipv6)
+ ruleIdx := filter.BuiltinChains[hook]
+ filter.Rules[ruleIdx].Filter = f
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+}
+
func TestForwardingHook(t *testing.T) {
const (
nicID1 = 1
@@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) {
},
}
- setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
- return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
-
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, ipv6)
- ruleIdx := filter.BuiltinChains[stack.Forward]
- filter.Rules[ruleIdx].Filter = f
- filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
- if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
- }
- }
- }
-
- boolToInt := func(v bool) uint64 {
- if v {
- return 1
- }
- return 0
- }
-
subTests := []struct {
name string
setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
@@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) {
{
name: "Drop",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
expectForward: false,
},
{
name: "Drop with input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
expectForward: false,
},
{
name: "Drop with output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with other input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
expectForward: true,
},
{
name: "Drop with input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with inverted input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
expectForward: true,
},
{
name: "Drop with inverted output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
expectForward: true,
},
}
@@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) {
})
}
}
+
+func TestInputHookWithLocalForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ rx func(*channel.Endpoint)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+ },
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv6Addr),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectDrop bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on arrival NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on delivered NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop with input NIC filtering on other NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectDrop: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ })
+
+ test.rx(e1)
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
+ t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip2Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ if p, ok := e1.Read(); ok == subTest.expectDrop {
+ t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
+ } else if !subTest.expectDrop {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ if p, ok := e2.Read(); ok {
+ t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index c657714ba..27caa0c28 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -17,6 +17,8 @@ package link_resolution_test
import (
"bytes"
"fmt"
+ "net"
+ "runtime"
"testing"
"time"
@@ -27,12 +29,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -283,9 +287,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: clock,
}
host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -329,7 +335,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
// Wait for an error due to link resolution failing, or the endpoint to be
// writable.
+ if test.expectedWriteErr != nil {
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+ } else {
+ clock.RunImmediatelyScheduledJobs()
+ }
<-ch
+
{
var r bytes.Reader
r.Reset([]byte{0})
@@ -395,6 +411,242 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
}
}
+func TestForwardingWithLinkResolutionFailure(t *testing.T) {
+ const (
+ incomingNICID = 1
+ outgoingNICID = 2
+ ttl = 2
+ expectedHostUnreachableErrorCount = 1
+ )
+ outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06")
+
+ rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+ }
+
+ rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+ }
+
+ arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
+ if request.Proto != arp.ProtocolNumber {
+ t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber)
+ }
+ if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(request.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != src {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst)
+ }
+ }
+
+ ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
+ if request.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber)
+ }
+
+ snmc := header.SolicitedNodeAddr(dst)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()),
+ checker.SrcAddr(src),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(dst),
+ ))
+ }
+
+ icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ipv4.DefaultTTL),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4HostUnreachable),
+ ),
+ )
+ }
+
+ icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ipv6.DefaultTTL),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6AddressUnreachable),
+ ),
+ )
+ }
+
+ tests := []struct {
+ name string
+ networkProtocolFactory []stack.NetworkProtocolFactory
+ networkProtocolNumber tcpip.NetworkProtocolNumber
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ incomingAddr tcpip.AddressWithPrefix
+ outgoingAddr tcpip.AddressWithPrefix
+ transportProtocol func(*stack.Stack) stack.TransportProtocol
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address)
+ icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address)
+ mtu uint32
+ }{
+ {
+ name: "IPv4 Host unreachable",
+ networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ networkProtocolNumber: header.IPv4ProtocolNumber,
+ sourceAddr: tcptestutil.MustParse4("10.0.0.2"),
+ destAddr: tcptestutil.MustParse4("11.0.0.2"),
+ incomingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ outgoingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ transportProtocol: icmp.NewProtocol4,
+ linkResolutionRequestChecker: arpChecker,
+ icmpReplyChecker: icmpv4Checker,
+ rx: rxICMPv4EchoRequest,
+ mtu: ipv4.MaxTotalSize,
+ },
+ {
+ name: "IPv6 Host unreachable",
+ networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ networkProtocolNumber: header.IPv6ProtocolNumber,
+ sourceAddr: tcptestutil.MustParse6("10::2"),
+ destAddr: tcptestutil.MustParse6("11::2"),
+ incomingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ outgoingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11::1").To16()),
+ PrefixLen: 64,
+ },
+ transportProtocol: icmp.NewProtocol6,
+ linkResolutionRequestChecker: ndpChecker,
+ icmpReplyChecker: icmpv6Checker,
+ rx: rxICMPv6EchoRequest,
+ mtu: header.IPv6MinimumMTU,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: test.networkProtocolFactory,
+ TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol},
+ Clock: clock,
+ })
+
+ // Set up endpoint through which we will receive packets.
+ incomingEndpoint := channel.New(1, test.mtu, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
+ }
+ incomingProtoAddr := tcpip.ProtocolAddress{
+ Protocol: test.networkProtocolNumber,
+ AddressWithPrefix: test.incomingAddr,
+ }
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ }
+
+ // Set up endpoint through which we will attempt to forward packets.
+ outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr)
+ outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
+ }
+ outgoingProtoAddr := tcpip.ProtocolAddress{
+ Protocol: test.networkProtocolNumber,
+ AddressWithPrefix: test.outgoingAddr,
+ }
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: test.incomingAddr.Subnet(),
+ NIC: incomingNICID,
+ },
+ {
+ Destination: test.outgoingAddr.Subnet(),
+ NIC: outgoingNICID,
+ },
+ })
+
+ if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err)
+ }
+
+ test.rx(incomingEndpoint, test.sourceAddr, test.destAddr)
+
+ nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err)
+ }
+ // Trigger the first packet on the endpoint.
+ clock.RunImmediatelyScheduledJobs()
+
+ for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ {
+ request, ok := outgoingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ARP packet through outgoing NIC")
+ }
+
+ test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr)
+
+ // Advance the clock the span of one request timeout.
+ clock.Advance(nudConfigs.RetransmitTimer)
+ }
+
+ // Next, we make a blocking read to retrieve the error packet. This is
+ // necessary because outgoing packets are dequeued asynchronously when
+ // link resolution fails, and this dequeue is what triggers the ICMP
+ // error.
+ reply, ok := incomingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ICMP packet through incoming NIC")
+ }
+
+ test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr)
+
+ // Since link resolution failed, we don't expect the packet to be
+ // forwarded.
+ forwardedPacket, ok := outgoingEndpoint.Read()
+ if ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want {
+ t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
func TestGetLinkAddress(t *testing.T) {
const (
host1NICID = 1
@@ -449,8 +701,10 @@ func TestGetLinkAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ Clock: clock,
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -466,8 +720,20 @@ func TestGetLinkAddress(t *testing.T) {
if test.expectedErr == nil {
wantRes.LinkAddress = utils.LinkAddr2
}
- if diff := cmp.Diff(wantRes, <-ch); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+ select {
+ case got := <-ch:
+ if diff := cmp.Diff(wantRes, got); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("event didn't arrive")
}
})
}
@@ -544,8 +810,10 @@ func TestRouteResolvedFields(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ Clock: clock,
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -575,8 +843,20 @@ func TestRouteResolvedFields(t *testing.T) {
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
}
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+
+ select {
+ case got := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatalf("event didn't arrive")
}
if test.expectedErr != nil {
@@ -789,11 +1069,16 @@ func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
d.c <- e
}
-func (d *nudDispatcher) waitForEvent(want eventInfo) error {
- if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+func (d *nudDispatcher) expectEvent(want eventInfo) error {
+ select {
+ case got := <-d.c:
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+ }
+ return nil
+ default:
+ return fmt.Errorf("event didn't arrive")
}
- return nil
}
// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
@@ -804,7 +1089,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
neighborAddr tcpip.Address
- getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{})
+ getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{})
isHost1Listener bool
}{
{
@@ -812,23 +1097,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -836,23 +1123,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -860,23 +1149,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -884,23 +1175,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -908,23 +1201,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -933,23 +1228,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -958,23 +1255,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -983,23 +1282,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -1037,14 +1338,14 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
}
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryAdded,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
}); err != nil {
t.Fatalf("error waiting for initial NUD event: %s", err)
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1064,7 +1365,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// See NUDConfigurations.BaseReachableTime for more information.
maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
clock.Advance(maxReachableTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1072,8 +1373,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for stale NUD event: %s", err)
}
- listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
- defer listenerEP.Close()
+ listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
defer clientEP.Close()
listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
if err := listenerEP.Bind(listenerAddr); err != nil {
@@ -1094,14 +1394,15 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// with confirmation that the neighbor is reachable (indicated by a
// successful 3-way handshake).
<-clientCH
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
}); err != nil {
t.Fatalf("error waiting for delay NUD event: %s", err)
}
- if err := nudDisp.waitForEvent(eventInfo{
+ <-listenerCH
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1109,26 +1410,55 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for reachable NUD event: %s", err)
}
+ peerEP, peerWQ, err := listenerEP.Accept(nil)
+ if err != nil {
+ t.Fatalf("listenerEP.Accept(): %s", err)
+ }
+ defer peerEP.Close()
+ peerWE, peerCH := waiter.NewChannelEntry(nil)
+ peerWQ.EventRegister(&peerWE, waiter.ReadableEvents)
+
// Wait for the neighbor to be stale again then send data to the remote.
//
// On successful transmission, the neighbor should become reachable
// without probing the neighbor as a TCP ACK would be received which is an
// indication of the neighbor being reachable.
clock.Advance(maxReachableTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
}); err != nil {
t.Fatalf("error waiting for stale NUD event: %s", err)
}
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- if _, err := clientEP.Write(&r, wOpts); err != nil {
- t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := clientEP.Write(&r, wOpts); err != nil {
+ t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ }
+ }
+ // Heads up, there is a race here.
+ //
+ // Incoming TCP segments are handled in
+ // tcp.(*endpoint).handleSegmentLocked:
+ //
+ // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the
+ // segment queue and notifies waiting readers (such as this channel)
+ //
+ // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment
+ // and notifies the NUD machinery that the peer is reachable
+ //
+ // Thus we must permit a delay between the readable signal and the
+ // expected NUD event.
+ //
+ // At the time of writing, this race is reliably hit with gotsan.
+ <-peerCH
+ for len(nudDisp.c) == 0 {
+ runtime.Gosched()
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1141,7 +1471,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// TCP should not mark the route reachable and NUD should go through the
// probe state.
clock.Advance(nudConfigs.DelayFirstProbeTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1149,7 +1479,16 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for probe NUD event: %s", err)
}
}
- if err := nudDisp.waitForEvent(eventInfo{
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := peerEP.Write(&r, wOpts); err != nil {
+ t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err)
+ }
+ }
+ <-clientCH
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 87d36e1dd..b4b2ec723 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -44,20 +44,17 @@ type ndpDispatcher struct{}
func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
- return false
+func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {
}
-func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {}
-func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
-func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return true
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}