summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go7
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go336
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go314
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go4
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go6
-rw-r--r--pkg/tcpip/tests/integration/route_test.go18
7 files changed, 653 insertions, 33 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 1742a178d..218b218e7 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -7,6 +7,7 @@ go_test(
size = "small",
srcs = [
"forward_test.go",
+ "iptables_test.go",
"link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index ac9670f9a..aedf1845e 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -436,9 +436,10 @@ func TestForwarding(t *testing.T) {
write := func(ep tcpip.Endpoint, data []byte) {
t.Helper()
- dataPayload := tcpip.SlicePayload(data)
+ var r bytes.Reader
+ r.Reset(data)
var wOpts tcpip.WriteOptions
- n, err := ep.Write(dataPayload, wOpts)
+ n, err := ep.Write(&r, wOpts)
if err != nil {
t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
}
@@ -486,7 +487,7 @@ func TestForwarding(t *testing.T) {
read(serverCH, serverEP, data, clientAddr)
- data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
write(serverEP, data)
read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
})
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
new file mode 100644
index 000000000..21a8dd291
--- /dev/null
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -0,0 +1,336 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type inputIfNameMatcher struct {
+ name string
+}
+
+var _ stack.Matcher = (*inputIfNameMatcher)(nil)
+
+func (*inputIfNameMatcher) Name() string {
+ return "inputIfNameMatcher"
+}
+
+func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) {
+ return (hook == stack.Input && im.name != "" && im.name == inNicName), false
+}
+
+const (
+ nicID = 1
+ nicName = "nic1"
+ anotherNicName = "nic2"
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ srcAddrV4 = "\x0a\x00\x00\x01"
+ dstAddrV4 = "\x0a\x00\x00\x02"
+ srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ payloadSize = 20
+)
+
+func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ })
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ }
+ return s, e
+}
+
+func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ e := channel.New(0, header.IPv4MinimumMTU, linkAddr)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ }
+ return s, e
+}
+
+func genPacketV6() *stack.PacketBuffer {
+ pktSize := header.IPv6MinimumSize + payloadSize
+ hdr := buffer.NewPrependable(pktSize)
+ ip := header.IPv6(hdr.Prepend(pktSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: payloadSize,
+ TransportProtocol: 99,
+ HopLimit: 255,
+ SrcAddr: srcAddrV6,
+ DstAddr: dstAddrV6,
+ })
+ vv := hdr.View().ToVectorisedView()
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
+}
+
+func genPacketV4() *stack.PacketBuffer {
+ pktSize := header.IPv4MinimumSize + payloadSize
+ hdr := buffer.NewPrependable(pktSize)
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&header.IPv4Fields{
+ TOS: 0,
+ TotalLength: uint16(pktSize),
+ ID: 1,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: 48,
+ Protocol: 99,
+ SrcAddr: srcAddrV4,
+ DstAddr: dstAddrV4,
+ })
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ vv := hdr.View().ToVectorisedView()
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
+}
+
+func TestIPTablesStatsForInput(t *testing.T) {
+ tests := []struct {
+ name string
+ setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint)
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func() *stack.PacketBuffer
+ proto tcpip.NetworkProtocolNumber
+ expectReceived int
+ expectInputDropped int
+ }{
+ {
+ name: "IPv6 Accept",
+ setupStack: genStackV6,
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept",
+ setupStack: genStackV4,
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop (input interface matches)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv4 Drop (input interface matches)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv6 Accept (input interface does not match)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept (input interface does not match)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop (input interface does not match but invert is true)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+ InputInterface: anotherNicName,
+ InputInterfaceInvert: true,
+ }
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv4 Drop (input interface does not match but invert is true)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+ InputInterface: anotherNicName,
+ InputInterfaceInvert: true,
+ }
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv6 Accept (input interface does not match using a matcher)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept (input interface does not match using a matcher)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := test.setupStack(t)
+ test.setupFilter(t, s)
+ e.InjectInbound(test.proto, test.genPacket())
+
+ if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived {
+ t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived)
+ }
+ if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped {
+ t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index af32d3009..f85164c5b 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -23,6 +23,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
@@ -32,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -207,8 +209,10 @@ func TestPing(t *testing.T) {
defer ep.Close()
icmpBuf := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(icmpBuf)
wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
- if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
} else if want := int64(len(icmpBuf)); n != want {
t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
@@ -358,9 +362,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
// Wait for an error due to link resolution failing, or the endpoint to be
// writable.
<-ch
+ var r bytes.Reader
+ r.Reset([]byte{0})
var wOpts tcpip.WriteOptions
- if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr {
- t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr)
+ if n, err := clientEP.Write(&r, wOpts); err != test.expectedWriteErr {
+ t.Errorf("got clientEP.Write(_, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr)
}
if test.expectedWriteErr == nil {
@@ -404,20 +410,34 @@ func TestGetLinkAddress(t *testing.T) {
)
tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedLinkAddr bool
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedOk bool
}{
{
- name: "IPv4",
+ name: "IPv4 resolvable",
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
},
{
- name: "IPv6",
+ name: "IPv6 resolvable",
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
+ },
+ {
+ name: "IPv4 not resolvable",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
+ },
+ {
+ name: "IPv6 not resolvable",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
},
}
@@ -432,27 +452,279 @@ func TestGetLinkAddress(t *testing.T) {
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- for i := 0; i < 2; i++ {
- addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {})
- var want *tcpip.Error
- if i == 0 {
- want = tcpip.ErrWouldBlock
+ ch := make(chan stack.LinkResolutionResult, 1)
+ if err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ }); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, tcpip.ErrWouldBlock)
+ }
+ wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
+ if test.expectedOk {
+ wantRes.LinkAddress = linkAddr2
+ }
+ if diff := cmp.Diff(wantRes, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestRouteResolvedFields(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ localAddr tcpip.Address
+ remoteAddr tcpip.Address
+ immediatelyResolvable bool
+ expectedSuccess bool
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4 immediately resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: header.IPv4AllSystems,
+ immediatelyResolvable: true,
+ expectedSuccess: true,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems),
+ },
+ {
+ name: "IPv6 immediately resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: header.IPv6AllNodesMulticastAddress,
+ immediatelyResolvable: true,
+ expectedSuccess: true,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
+ },
+ {
+ name: "IPv4 resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: true,
+ expectedLinkAddr: linkAddr2,
+ },
+ {
+ name: "IPv6 resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: true,
+ expectedLinkAddr: linkAddr2,
+ },
+ {
+ name: "IPv4 not resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: false,
+ },
+ {
+ name: "IPv6 not resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, useNeighborCache := range []bool{true, false} {
+ t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ UseNeighborCache: useNeighborCache,
+ }
+
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
+
+ var wantRouteInfo stack.RouteInfo
+ wantRouteInfo.LocalLinkAddress = linkAddr1
+ wantRouteInfo.LocalAddress = test.localAddr
+ wantRouteInfo.RemoteAddress = test.remoteAddr
+ wantRouteInfo.NetProto = test.netProto
+ wantRouteInfo.Loop = stack.PacketOut
+ wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
+
+ ch := make(chan stack.ResolvedFieldsResult, 1)
+
+ if !test.immediatelyResolvable {
+ wantUnresolvedRouteInfo := wantRouteInfo
+ wantUnresolvedRouteInfo.RemoteLinkAddress = ""
+
+ if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ }); err != tcpip.ErrWouldBlock {
+ t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
- if err != want {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want)
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
}
- if i == 0 {
- <-ch
- continue
+ if !test.expectedSuccess {
+ return
}
- if addr != linkAddr2 {
- t.Fatalf("got addr = %s, want = %s", addr, linkAddr2)
+ // At this point the neighbor table should be populated so the route
+ // should be immediately resolvable.
+ }
+
+ if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ }); err != nil {
+ t.Errorf("r.ResolvedFields(_): %s", err)
+ }
+ select {
+ case routeResolveRes := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected route to be immediately resolvable")
}
})
}
})
}
}
+
+func TestWritePacketsLinkResolution(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedWriteErr *tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ var serverWQ waiter.Queue
+ serverWE, serverCH := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverWE, waiter.EventIn)
+ serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
+ }
+ defer serverEP.Close()
+
+ serverAddr := tcpip.FullAddress{Port: 1234}
+ if err := serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
+
+ data := []byte{1, 2}
+ var pkts stack.PacketBufferList
+ for _, d := range data {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: buffer.View([]byte{d}).ToVectorisedView(),
+ })
+ pkt.TransportProtocolNumber = udp.ProtocolNumber
+ length := uint16(pkt.Size())
+ udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: serverAddr.Port,
+ Length: length,
+ })
+ xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+
+ pkts.PushBack(pkt)
+ }
+
+ params := stack.NetworkHeaderParams{
+ Protocol: udp.ProtocolNumber,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+
+ if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil {
+ t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err)
+ } else if want := pkts.Len(); want != n {
+ t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want)
+ }
+
+ var writer bytes.Buffer
+ count := 0
+ for {
+ var rOpts tcpip.ReadOptions
+ res, err := serverEP.Read(&writer, rOpts)
+ if err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // Should not have anymore bytes to read after we read the sent
+ // number of bytes.
+ if count == len(data) {
+ break
+ }
+
+ <-serverCH
+ continue
+ }
+
+ t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
+ }
+ count += res.Count
+ }
+
+ if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
+ t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
+ t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 3b13ba04d..761283b66 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
Port: localPort,
},
}
- n, err := sep.Write(tcpip.SlicePayload(data), wopts)
+ var r bytes.Reader
+ r.Reset(data)
+ n, err := sep.Write(&r, wopts)
if err != nil {
t.Fatalf("sep.Write(_, _): %s", err)
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index ce7c16bd1..9cc12fa58 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
Port: localPort,
},
}
- data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
- if n, err := wep.ep.Write(data, writeOpts); err != nil {
+ data := []byte{byte(i), 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := wep.ep.Write(&r, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index b222d2b05..35ee7437a 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -194,9 +194,11 @@ func TestLocalPing(t *testing.T) {
return
}
- payload := tcpip.SlicePayload(test.icmpBuf(t))
+ payload := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(payload)
var wOpts tcpip.WriteOptions
- if n, err := ep.Write(payload, wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
} else if n != int64(len(payload)) {
t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
@@ -329,12 +331,14 @@ func TestLocalUDP(t *testing.T) {
Port: 80,
}
- clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ clientPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(clientPayload)
wOpts := tcpip.WriteOptions{
To: &serverAddr,
}
- if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
} else if subTest.expectedWriteErr != nil {
// Nothing else to test if we expected not to be able to send the
@@ -376,12 +380,14 @@ func TestLocalUDP(t *testing.T) {
}
}
- serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ serverPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(serverPayload)
wOpts := tcpip.WriteOptions{
To: &clientAddr,
}
- if n, err := server.Write(serverPayload, wOpts); err != nil {
+ if n, err := server.Write(&r, wOpts); err != nil {
t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
} else if n != int64(len(serverPayload)) {
t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))