diff options
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 321 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 336 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 778 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/route_test.go | 39 |
7 files changed, 1308 insertions, 220 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 1742a178d..71695b630 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -7,6 +7,7 @@ go_test( size = "small", srcs = [ "forward_test.go", + "iptables_test.go", "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", @@ -16,6 +17,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ac9670f9a..38e1881c7 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -38,96 +38,207 @@ import ( var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) -// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { - var e endpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// endpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type endpointWithDestinationCheck struct { - nested.Endpoint -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher. -func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -func TestForwarding(t *testing.T) { - const ( - host1NICID = 1 - routerNICID1 = 2 - routerNICID2 = 3 - host2NICID = 4 - - listenPort = 8080 - ) +const ( + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 +) - host1IPv4Addr := tcpip.ProtocolAddress{ +var ( + host1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 24, }, } - routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + routerNIC1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + routerNIC2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, }, } - host2IPv4Addr := tcpip.ProtocolAddress{ + host2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr := tcpip.ProtocolAddress{ + host1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), PrefixLen: 64, }, } - routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + routerNIC1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + routerNIC2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr := tcpip.ProtocolAddress{ + host2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::2").To16()), PrefixLen: 64, }, } +) + +func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) +} + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + +func TestForwarding(t *testing.T) { + const listenPort = 8080 type endpointAndAddresses struct { serverEP tcpip.Endpoint @@ -229,7 +340,7 @@ func TestForwarding(t *testing.T) { subTests := []struct { name string proto tcpip.TransportProtocolNumber - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) needRemoteAddr bool }{ @@ -250,7 +361,7 @@ func TestForwarding(t *testing.T) { { name: "TCP", proto: tcp.ProtocolNumber, - expectedConnectErr: tcpip.ErrConnectStarted, + expectedConnectErr: &tcpip.ErrConnectStarted{}, setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() @@ -260,7 +371,7 @@ func TestForwarding(t *testing.T) { var addr tcpip.FullAddress for { newEP, wq, err := ep.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-ch continue } @@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) { host1Stack := stack.New(stackOpts) routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) defer epsAndAddrs.serverEP.Close() @@ -415,8 +420,11 @@ func TestForwarding(t *testing.T) { t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) } - if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr { - t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr) + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } } if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) @@ -436,9 +444,10 @@ func TestForwarding(t *testing.T) { write := func(ep tcpip.Endpoint, data []byte) { t.Helper() - dataPayload := tcpip.SlicePayload(data) + var r bytes.Reader + r.Reset(data) var wOpts tcpip.WriteOptions - n, err := ep.Write(dataPayload, wOpts) + n, err := ep.Write(&r, wOpts) if err != nil { t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) } @@ -486,7 +495,7 @@ func TestForwarding(t *testing.T) { read(serverCH, serverEP, data, clientAddr) - data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + data = []byte{5, 6, 7, 8, 9, 10, 11, 12} write(serverEP, data) read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) }) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go new file mode 100644 index 000000000..21a8dd291 --- /dev/null +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -0,0 +1,336 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type inputIfNameMatcher struct { + name string +} + +var _ stack.Matcher = (*inputIfNameMatcher)(nil) + +func (*inputIfNameMatcher) Name() string { + return "inputIfNameMatcher" +} + +func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { + return (hook == stack.Input && im.name != "" && im.name == inNicName), false +} + +const ( + nicID = 1 + nicName = "nic1" + anotherNicName = "nic2" + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + srcAddrV4 = "\x0a\x00\x00\x01" + dstAddrV4 = "\x0a\x00\x00\x02" + srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + payloadSize = 20 +) + +func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + } + return s, e +} + +func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + e := channel.New(0, header.IPv4MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + } + return s, e +} + +func genPacketV6() *stack.PacketBuffer { + pktSize := header.IPv6MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv6(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadSize, + TransportProtocol: 99, + HopLimit: 255, + SrcAddr: srcAddrV6, + DstAddr: dstAddrV6, + }) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func genPacketV4() *stack.PacketBuffer { + pktSize := header.IPv4MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv4Fields{ + TOS: 0, + TotalLength: uint16(pktSize), + ID: 1, + Flags: 0, + FragmentOffset: 16, + TTL: 48, + Protocol: 99, + SrcAddr: srcAddrV4, + DstAddr: dstAddrV4, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func TestIPTablesStatsForInput(t *testing.T) { + tests := []struct { + name string + setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) + setupFilter func(*testing.T, *stack.Stack) + genPacket func() *stack.PacketBuffer + proto tcpip.NetworkProtocolNumber + expectReceived int + expectInputDropped int + }{ + { + name: "IPv6 Accept", + setupStack: genStackV6, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept", + setupStack: genStackV4, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface matches)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface matches)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface does not match but invert is true)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface does not match but invert is true)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match using a matcher)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match using a matcher)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := test.setupStack(t) + test.setupFilter(t, s) + e.InjectInbound(test.proto, test.genPacket()) + + if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { + t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) + } + if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { + t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index af32d3009..7069352f2 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -19,11 +19,14 @@ import ( "fmt" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -32,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -207,8 +211,10 @@ func TestPing(t *testing.T) { defer ep.Close() icmpBuf := test.icmpBuf(t) + var r bytes.Reader + r.Reset(icmpBuf) wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(_, _): %s", err) } else if want := int64(len(icmpBuf)); n != want { t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) @@ -251,7 +257,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name string netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error sockError tcpip.SockError }{ { @@ -270,9 +276,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv4 without resolvable remote", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, + Err: &tcpip.ErrNoRoute{}, ErrType: byte(header.ICMPv4DstUnreachable), ErrCode: byte(header.ICMPv4HostUnreachable), ErrOrigin: tcpip.SockExtErrorOriginICMP, @@ -292,9 +298,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv6 without resolvable remote", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, + Err: &tcpip.ErrNoRoute{}, ErrType: byte(header.ICMPv6DstUnreachable), ErrCode: byte(header.ICMPv6AddressUnreachable), ErrOrigin: tcpip.SockExtErrorOriginICMP6, @@ -351,16 +357,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr := listenerAddr remoteAddr.Addr = test.remoteAddr - if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted) + { + err := clientEP.Connect(remoteAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) + } } // Wait for an error due to link resolution failing, or the endpoint to be // writable. <-ch - var wOpts tcpip.WriteOptions - if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr { - t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + _, err := clientEP.Write(&r, wOpts) + if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { + t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) + } } if test.expectedWriteErr == nil { @@ -374,7 +388,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { sockErrCmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b *tcpip.Error) bool { + cmp.Comparer(func(a, b tcpip.Error) bool { // tcpip.Error holds an unexported field but the errors netstack uses // are pre defined so we can simply compare pointers. return a == b @@ -404,20 +418,134 @@ func TestGetLinkAddress(t *testing.T) { ) tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedLinkAddr bool + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedOk bool }{ { - name: "IPv4", + name: "IPv4 resolvable", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedOk: true, }, { - name: "IPv6", + name: "IPv6 resolvable", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedOk: true, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, useNeighborCache := range []bool{true, false} { + t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + UseNeighborCache: useNeighborCache, + } + + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + wantRes := stack.LinkResolutionResult{Success: test.expectedOk} + if test.expectedOk { + wantRes.LinkAddress = linkAddr2 + } + if diff := cmp.Diff(wantRes, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) + } + }) + } +} + +func TestRouteResolvedFields(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + localAddr tcpip.Address + remoteAddr tcpip.Address + immediatelyResolvable bool + expectedSuccess bool + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4 immediately resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv4AllSystems, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems), + }, + { + name: "IPv6 immediately resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv6AllNodesMulticastAddress, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, + { + name: "IPv4 resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv6 resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, }, } @@ -431,28 +559,618 @@ func TestGetLinkAddress(t *testing.T) { } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + var wantRouteInfo stack.RouteInfo + wantRouteInfo.LocalLinkAddress = linkAddr1 + wantRouteInfo.LocalAddress = test.localAddr + wantRouteInfo.RemoteAddress = test.remoteAddr + wantRouteInfo.NetProto = test.netProto + wantRouteInfo.Loop = stack.PacketOut + wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr + + ch := make(chan stack.ResolvedFieldsResult, 1) + + if !test.immediatelyResolvable { + wantUnresolvedRouteInfo := wantRouteInfo + wantUnresolvedRouteInfo.RemoteLinkAddress = "" - for i := 0; i < 2; i++ { - addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {}) - var want *tcpip.Error - if i == 0 { - want = tcpip.ErrWouldBlock + err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } - if err != want { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want) + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) } - if i == 0 { - <-ch - continue + if !test.expectedSuccess { + return } - if addr != linkAddr2 { - t.Fatalf("got addr = %s, want = %s", addr, linkAddr2) + // At this point the neighbor table should be populated so the route + // should be immediately resolvable. + } + + if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }); err != nil { + t.Errorf("r.ResolvedFields(_): %s", err) + } + select { + case routeResolveRes := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected route to be immediately resolvable") } }) } }) } } + +func TestWritePacketsLinkResolution(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedWriteErr tcpip.Error + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + } + + host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) + } + defer serverEP.Close() + + serverAddr := tcpip.FullAddress{Port: 1234} + if err := serverEP.Bind(serverAddr); err != nil { + t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) + } + + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + data := []byte{1, 2} + var pkts stack.PacketBufferList + for _, d := range data { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: buffer.View([]byte{d}).ToVectorisedView(), + }) + pkt.TransportProtocolNumber = udp.ProtocolNumber + length := uint16(pkt.Size()) + udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: serverAddr.Port, + Length: length, + }) + xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + + pkts.PushBack(pkt) + } + + params := stack.NetworkHeaderParams{ + Protocol: udp.ProtocolNumber, + TTL: 64, + TOS: stack.DefaultTOS, + } + + if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { + t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) + } else if want := pkts.Len(); want != n { + t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) + } + + var writer bytes.Buffer + count := 0 + for { + var rOpts tcpip.ReadOptions + res, err := serverEP.Read(&writer, rOpts) + if err != nil { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + // Should not have anymore bytes to read after we read the sent + // number of bytes. + if count == len(data) { + break + } + + <-serverCH + continue + } + + t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) + } + count += res.Count + } + + if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { + t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if diff := cmp.Diff(data, writer.Bytes()); diff != "" { + t.Errorf("read bytes mismatch (-want +got):\n%s", diff) + } + }) + } +} + +type eventType int + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved +) + +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + entry stack.NeighborEntry +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) +} + +var _ stack.NUDDispatcher = (*nudDispatcher)(nil) + +type nudDispatcher struct { + c chan eventInfo +} + +func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) waitForEvent(want eventInfo) error { + if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil +} + +// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it +// that the neighbor used for a route is reachable. +func TestTCPConfirmNeighborReachability(t *testing.T) { + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + neighborAddr tcpip.Address + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + isHost1Listener bool + }{ + { + name: "IPv4 active connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 active connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 passive connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv4 passive connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + nudDisp := nudDispatcher{ + c: make(chan eventInfo, 3), + } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, + UseNeighborCache: true, + } + host1StackOpts := stackOpts + host1StackOpts.NUDDisp = &nudDisp + + host1Stack := stack.New(host1StackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + // Add a reachable dynamic entry to our neighbor table for the remote. + { + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryAdded, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, + }); err != nil { + t.Fatalf("error waiting for initial NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the remote's neighbor entry to be stale before creating a + // TCP connection from host1 to some remote. + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err) + } + // The maximum reachable time for a neighbor is some maximum random factor + // applied to the base reachable time. + // + // See NUDConfigurations.BaseReachableTime for more information. + maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + + listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) + defer listenerEP.Close() + defer clientEP.Close() + listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + { + err := clientEP.Connect(listenerAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) + } + } + + // Wait for the TCP handshake to complete then make sure the neighbor is + // reachable without entering the probe state as TCP should provide NUD + // with confirmation that the neighbor is reachable (indicated by a + // successful 3-way handshake). + <-clientCH + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the neighbor to be stale again then send data to the remote. + // + // On successful transmission, the neighbor should become reachable + // without probing the neighbor as a TCP ACK would be received which is an + // indication of the neighbor being reachable. + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if test.isHost1Listener { + // If host1 is not the client, host1 does not send any data so TCP + // has no way to know it is making forward progress. Because of this, + // TCP should not mark the route reachable and NUD should go through the + // probe state. + clock.Advance(nudConfigs.DelayFirstProbeTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for probe NUD event: %s", err) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3b13ba04d..ab67762ef 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -37,7 +37,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) type ndpDispatcher struct{} -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { @@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { Port: localPort, }, } - n, err := sep.Write(tcpip.SlicePayload(data), wopts) + var r bytes.Reader + r.Reset(data) + n, err := sep.Write(&r, wopts) if err != nil { t.Fatalf("sep.Write(_, _): %s", err) } @@ -260,8 +262,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -320,11 +322,14 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) } - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + { + err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) + } } } @@ -468,13 +473,17 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { Addr: test.dstAddr, Port: localPort, } - if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + { + err := connectingEndpoint.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } } if !test.expectAccept { - if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + _, _, err := listeningEndpoint.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } return } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index ce7c16bd1..d685fdd36 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -479,8 +479,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) - if n, err := wep.ep.Write(data, writeOpts); err != nil { + data := []byte{byte(i), 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := wep.ep.Write(&r, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) @@ -759,8 +761,11 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { if err := ep.SetSockOpt(&removeOpt); err != nil { t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) } - if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) + { + _, err := ep.Read(&buf, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) + } } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index b222d2b05..9654c9527 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -81,7 +81,7 @@ func TestLocalPing(t *testing.T) { linkEndpoint func() stack.LinkEndpoint localAddr tcpip.Address icmpBuf func(*testing.T) buffer.View - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) }{ { @@ -126,7 +126,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -135,7 +135,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -144,7 +144,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, { @@ -153,7 +153,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, } @@ -186,17 +186,22 @@ func TestLocalPing(t *testing.T) { defer ep.Close() connAddr := tcpip.FullAddress{Addr: test.localAddr} - if err := ep.Connect(connAddr); err != test.expectedConnectErr { - t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + { + err := ep.Connect(connAddr) + if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) + } } if test.expectedConnectErr != nil { return } - payload := tcpip.SlicePayload(test.icmpBuf(t)) + payload := test.icmpBuf(t) + var r bytes.Reader + r.Reset(payload) var wOpts tcpip.WriteOptions - if n, err := ep.Write(payload, wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) } else if n != int64(len(payload)) { t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) @@ -261,12 +266,12 @@ func TestLocalUDP(t *testing.T) { subTests := []struct { name string addAddress bool - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error }{ { name: "Unassigned local address", addAddress: false, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, }, { name: "Assigned local address", @@ -329,12 +334,14 @@ func TestLocalUDP(t *testing.T) { Port: 80, } - clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + clientPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(clientPayload) wOpts := tcpip.WriteOptions{ To: &serverAddr, } - if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) } else if subTest.expectedWriteErr != nil { // Nothing else to test if we expected not to be able to send the @@ -376,12 +383,14 @@ func TestLocalUDP(t *testing.T) { } } - serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + serverPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(serverPayload) wOpts := tcpip.WriteOptions{ To: &clientAddr, } - if n, err := server.Write(serverPayload, wOpts); err != nil { + if n, err := server.Write(&r, wOpts); err != nil { t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) } else if n != int64(len(serverPayload)) { t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) |