summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD30
-rw-r--r--pkg/tcpip/network/arp/BUILD53
-rw-r--r--pkg/tcpip/network/arp/arp_state_autogen.go3
-rw-r--r--pkg/tcpip/network/arp/arp_test.go713
-rw-r--r--pkg/tcpip/network/arp/stats_test.go51
-rw-r--r--pkg/tcpip/network/hash/BUILD13
-rw-r--r--pkg/tcpip/network/hash/hash_state_autogen.go3
-rw-r--r--pkg/tcpip/network/internal/fragmentation/BUILD54
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go68
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go636
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_list.go221
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_test.go233
-rw-r--r--pkg/tcpip/network/internal/ip/BUILD40
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go381
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go808
-rw-r--r--pkg/tcpip/network/internal/ip/ip_state_autogen.go3
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD20
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go129
-rw-r--r--pkg/tcpip/network/ip_test.go2020
-rw-r--r--pkg/tcpip/network/ipv4/BUILD68
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go387
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_state_autogen.go113
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go3228
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go99
-rw-r--r--pkg/tcpip/network/ipv6/BUILD72
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go1716
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_state_autogen.go136
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go3444
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go601
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go1376
-rw-r--r--pkg/tcpip/network/multicast_group_test.go1278
31 files changed, 547 insertions, 17450 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
deleted file mode 100644
index 7b1ff44f4..000000000
--- a/pkg/tcpip/network/BUILD
+++ /dev/null
@@ -1,30 +0,0 @@
-load("//tools:defs.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-go_test(
- name = "ip_test",
- size = "small",
- srcs = [
- "ip_test.go",
- "multicast_group_test.go",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
deleted file mode 100644
index a72eb1aad..000000000
--- a/pkg/tcpip/network/arp/BUILD
+++ /dev/null
@@ -1,53 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "arp",
- srcs = [
- "arp.go",
- "stats.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/internal/ip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "arp_test",
- size = "small",
- srcs = ["arp_test.go"],
- deps = [
- ":arp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
- ],
-)
-
-go_test(
- name = "stats_test",
- size = "small",
- srcs = ["stats_test.go"],
- library = ":arp",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- ],
-)
diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go
new file mode 100644
index 000000000..5cd8535e3
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package arp
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
deleted file mode 100644
index 94209b026..000000000
--- a/pkg/tcpip/network/arp/arp_test.go
+++ /dev/null
@@ -1,713 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arp_test
-
-import (
- "context"
- "fmt"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
-)
-
-const (
- nicID = 1
-
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
-
- defaultChannelSize = 1
- defaultMTU = 65536
-
- // eventChanSize defines the size of event channels used by the neighbor
- // cache's event dispatcher. The size chosen here needs to be sufficient to
- // queue all the events received during tests before consumption.
- // If eventChanSize is too small, the tests may deadlock.
- eventChanSize = 32
-)
-
-var (
- stackAddr = testutil.MustParse4("10.0.0.1")
- remoteAddr = testutil.MustParse4("10.0.0.2")
- unknownAddr = testutil.MustParse4("10.0.0.3")
-)
-
-type eventType uint8
-
-const (
- entryAdded eventType = iota
- entryChanged
- entryRemoved
-)
-
-func (t eventType) String() string {
- switch t {
- case entryAdded:
- return "add"
- case entryChanged:
- return "change"
- case entryRemoved:
- return "remove"
- default:
- return fmt.Sprintf("unknown (%d)", t)
- }
-}
-
-type eventInfo struct {
- eventType eventType
- nicID tcpip.NICID
- entry stack.NeighborEntry
-}
-
-func (e eventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
-}
-
-// arpDispatcher implements NUDDispatcher to validate the dispatching of
-// events upon certain NUD state machine events.
-type arpDispatcher struct {
- // C is where events are queued
- C chan eventInfo
-}
-
-var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
-
-func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: entry,
- }
- d.C <- e
-}
-
-func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryChanged,
- nicID: nicID,
- entry: entry,
- }
- d.C <- e
-}
-
-func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryRemoved,
- nicID: nicID,
- entry: entry,
- }
- d.C <- e
-}
-
-func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
- select {
- case got := <-d.C:
- if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
- }
- case <-ctx.Done():
- return fmt.Errorf("%s for %s", ctx.Err(), want)
- }
- return nil
-}
-
-func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error {
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- return d.waitForEvent(ctx, want)
-}
-
-func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
- select {
- case event := <-d.C:
- return event, true
- default:
- return eventInfo{}, false
- }
-}
-
-type testContext struct {
- s *stack.Stack
- linkEP *channel.Endpoint
- nudDisp *arpDispatcher
-}
-
-func newTestContext(t *testing.T) *testContext {
- c := stack.DefaultNUDConfigurations()
- // Transition from Reachable to Stale almost immediately to test if receiving
- // probes refreshes positive reachability.
- c.BaseReachableTime = time.Microsecond
-
- d := arpDispatcher{
- // Create an event channel large enough so the neighbor cache doesn't block
- // while dispatching events. Blocking could interfere with the timing of
- // NUD transitions.
- C: make(chan eventInfo, eventChanSize),
- }
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- NUDConfigs: c,
- NUDDisp: &d,
- })
-
- ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- wep := stack.LinkEndpoint(ep)
-
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- if err := s.CreateNIC(nicID, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- }})
-
- return &testContext{
- s: s,
- linkEP: ep,
- nudDisp: &d,
- }
-}
-
-func (c *testContext) cleanup() {
- c.linkEP.Close()
-}
-
-func TestMalformedPacket(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
-
- v := make(buffer.View, header.ARPSize)
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
-
- c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
-
- if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got)
- }
-}
-
-func TestDisabledEndpoint(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
-
- ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber)
- if err != nil {
- t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err)
- }
- ep.Disable()
-
- v := make(buffer.View, header.ARPSize)
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
-
- c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
-
- if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got)
- }
-}
-
-func TestDirectReply(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
-
- const senderMAC = "\x01\x02\x03\x04\x05\x06"
- const senderIPv4 = "\x0a\x00\x00\x02"
-
- v := make(buffer.View, header.ARPSize)
- h := header.ARP(v)
- h.SetIPv4OverEthernet()
- h.SetOp(header.ARPReply)
-
- copy(h.HardwareAddressSender(), senderMAC)
- copy(h.ProtocolAddressSender(), senderIPv4)
- copy(h.HardwareAddressTarget(), stackLinkAddr)
- copy(h.ProtocolAddressTarget(), stackAddr)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
-
- c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
-
- if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
-}
-
-func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
-
- tests := []struct {
- name string
- senderAddr tcpip.Address
- senderLinkAddr tcpip.LinkAddress
- targetAddr tcpip.Address
- isValid bool
- }{
- {
- name: "Loopback",
- senderAddr: stackAddr,
- senderLinkAddr: stackLinkAddr,
- targetAddr: stackAddr,
- isValid: true,
- },
- {
- name: "Remote",
- senderAddr: remoteAddr,
- senderLinkAddr: remoteLinkAddr,
- targetAddr: stackAddr,
- isValid: true,
- },
- {
- name: "RemoteInvalidTarget",
- senderAddr: remoteAddr,
- senderLinkAddr: remoteLinkAddr,
- targetAddr: unknownAddr,
- isValid: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- packetsRecv := c.s.Stats().ARP.PacketsReceived.Value()
- requestsRecv := c.s.Stats().ARP.RequestsReceived.Value()
- requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value()
- outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value()
-
- // Inject an incoming ARP request.
- v := make(buffer.View, header.ARPSize)
- h := header.ARP(v)
- h.SetIPv4OverEthernet()
- h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), test.senderLinkAddr)
- copy(h.ProtocolAddressSender(), test.senderAddr)
- copy(h.ProtocolAddressTarget(), test.targetAddr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- }))
-
- if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want)
- }
- if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want)
- }
-
- if !test.isValid {
- // No packets should be sent after receiving an invalid ARP request.
- // There is no need to perform a blocking read here, since packets are
- // sent in the same function that handles ARP requests.
- if pkt, ok := c.linkEP.Read(); ok {
- t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto)
- }
- if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want {
- t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want)
- }
- if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want {
- t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want)
- }
-
- return
- }
-
- if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want {
- t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want)
- }
-
- // Verify an ARP response was sent.
- pi, ok := c.linkEP.Read()
- if !ok {
- t.Fatal("expected ARP response to be sent, got none")
- }
-
- if pi.Proto != arp.ProtocolNumber {
- t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
- }
- rep := header.ARP(pi.Pkt.NetworkHeader().View())
- if !rep.IsValid() {
- t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
- t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
- t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
- t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
- t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want)
- }
-
- // Verify the sender was saved in the neighbor cache.
- wantEvent := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: stack.NeighborEntry{
- Addr: test.senderAddr,
- LinkAddr: tcpip.LinkAddress(test.senderLinkAddr),
- State: stack.Stale,
- },
- }
- if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
- t.Fatal(err)
- }
-
- neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err)
- }
-
- neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
- for _, n := range neighbors {
- if existing, ok := neighborByAddr[n.Addr]; ok {
- if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff)
- }
- t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr)
- }
- neighborByAddr[n.Addr] = n
- }
-
- neigh, ok := neighborByAddr[test.senderAddr]
- if !ok {
- t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr)
- }
- if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
- t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
- }
- if got, want := neigh.State, stack.Stale; got != want {
- t.Errorf("got neighbor State = %s, want = %s", got, want)
- }
-
- // No more events should be dispatched
- for {
- event, ok := c.nudDisp.nextEvent()
- if !ok {
- break
- }
- t.Errorf("unexpected %s", event)
- }
- })
- }
-}
-
-var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil)
-
-type testLinkEndpoint struct {
- stack.LinkEndpoint
-
- writeErr tcpip.Error
-}
-
-func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- if t.writeErr != nil {
- return t.writeErr
- }
-
- return t.LinkEndpoint.WritePacket(r, protocol, pkt)
-}
-
-func TestLinkAddressRequest(t *testing.T) {
- const nicID = 1
-
- testAddr := tcpip.Address([]byte{1, 2, 3, 4})
-
- tests := []struct {
- name string
- nicAddr tcpip.Address
- localAddr tcpip.Address
- remoteLinkAddr tcpip.LinkAddress
- linkErr tcpip.Error
- expectedErr tcpip.Error
- expectedLocalAddr tcpip.Address
- expectedRemoteLinkAddr tcpip.LinkAddress
- expectedRequestsSent uint64
- expectedRequestBadLocalAddressErrors uint64
- expectedRequestInterfaceHasNoLocalAddressErrors uint64
- expectedRequestDroppedErrors uint64
- }{
- {
- name: "Unicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Unicast with unspecified source",
- nicAddr: stackAddr,
- localAddr: "",
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast with unspecified source",
- nicAddr: stackAddr,
- localAddr: "",
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Unicast with unassigned address",
- nicAddr: stackAddr,
- localAddr: testAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast with unassigned address",
- nicAddr: stackAddr,
- localAddr: testAddr,
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrBadLocalAddress{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Unicast with no local address available",
- nicAddr: "",
- localAddr: "",
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 1,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast with no local address available",
- nicAddr: "",
- localAddr: "",
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 1,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Link error",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- linkErr: &tcpip.ErrInvalidEndpointState{},
- expectedErr: &tcpip.ErrInvalidEndpointState{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- })
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err)
- }
- linkRes, ok := ep.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
- }
-
- if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
- }
- }
-
- {
- err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
- }
- }
-
- if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
- t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent)
- }
- if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors)
- }
- if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors)
- }
- if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors)
- }
-
- if test.expectedErr != nil {
- return
- }
-
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
-
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
- }
-
- rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if got := rep.Op(); got != header.ARPRequest {
- t.Errorf("got Op = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
- }
- if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
- t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
- }
- if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr)
- }
- })
- }
-}
-
-func TestDADARPRequestPacket(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- },
- }), ipv4.NewProtocol},
- })
- e := channel.New(1, defaultMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- if res, err := s.CheckDuplicateAddress(nicID, header.IPv4ProtocolNumber, remoteAddr, func(stack.DADResult) {}); err != nil {
- t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, header.IPv4ProtocolNumber, remoteAddr, err)
- } else if res != stack.DADStarting {
- t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting)
- }
-
- pkt, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatal("expected to send an ARP request")
- }
-
- if pkt.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
- }
-
- req := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if !req.IsValid() {
- t.Errorf("got req.IsValid() = false, want = true")
- }
- if got := req.Op(); got != header.ARPRequest {
- t.Errorf("got req.Op() = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(req.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("got req.HardwareAddressSender() = %s, want = %s", got, stackLinkAddr)
- }
- if got := tcpip.Address(req.ProtocolAddressSender()); got != header.IPv4Any {
- t.Errorf("got req.ProtocolAddressSender() = %s, want = %s", got, header.IPv4Any)
- }
- if got, want := tcpip.LinkAddress(req.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
- t.Errorf("got req.HardwareAddressTarget() = %s, want = %s", got, want)
- }
- if got := tcpip.Address(req.ProtocolAddressTarget()); got != remoteAddr {
- t.Errorf("got req.ProtocolAddressTarget() = %s, want = %s", got, remoteAddr)
- }
-}
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
deleted file mode 100644
index 0df39ae81..000000000
--- a/pkg/tcpip/network/arp/stats_test.go
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arp
-
-import (
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.NetworkInterface
- nicID tcpip.NICID
-}
-
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func TestMultiCounterStatsInitialization(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- // At this point, the Stack's stats and the NetworkEndpoint's stats are
- // expected to be bound by a MultiCounterStat.
- refStack := s.Stats()
- refEP := ep.stats.localStats
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
deleted file mode 100644
index 872165866..000000000
--- a/pkg/tcpip/network/hash/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "hash",
- srcs = ["hash.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/rand",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go
new file mode 100644
index 000000000..9467fe298
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package hash
diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD
deleted file mode 100644
index 274f09092..000000000
--- a/pkg/tcpip/network/internal/fragmentation/BUILD
+++ /dev/null
@@ -1,54 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "reassembler_list",
- out = "reassembler_list.go",
- package = "fragmentation",
- prefix = "reassembler",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*reassembler",
- "Linker": "*reassembler",
- },
-)
-
-go_library(
- name = "fragmentation",
- srcs = [
- "fragmentation.go",
- "reassembler.go",
- "reassembler_list.go",
- ],
- visibility = [
- "//pkg/tcpip/network/ipv4:__pkg__",
- "//pkg/tcpip/network/ipv6:__pkg__",
- ],
- deps = [
- "//pkg/log",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "fragmentation_test",
- size = "small",
- srcs = [
- "fragmentation_test.go",
- "reassembler_test.go",
- ],
- library = ":fragmentation",
- deps = [
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/network/internal/testutil",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go
new file mode 100644
index 000000000..21c5774e9
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go
@@ -0,0 +1,68 @@
+// automatically generated by stateify.
+
+package fragmentation
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (l *reassemblerList) StateTypeName() string {
+ return "pkg/tcpip/network/internal/fragmentation.reassemblerList"
+}
+
+func (l *reassemblerList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *reassemblerList) beforeSave() {}
+
+// +checklocksignore
+func (l *reassemblerList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *reassemblerList) afterLoad() {}
+
+// +checklocksignore
+func (l *reassemblerList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *reassemblerEntry) StateTypeName() string {
+ return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry"
+}
+
+func (e *reassemblerEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *reassemblerEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *reassemblerEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *reassemblerEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*reassemblerList)(nil))
+ state.Register((*reassemblerEntry)(nil))
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
deleted file mode 100644
index 7daf64b4a..000000000
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
+++ /dev/null
@@ -1,636 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "errors"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-// reassembleTimeout is dummy timeout used for testing, where the clock never
-// advances.
-const reassembleTimeout = 1
-
-// vv is a helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) buffer.VectorisedView {
- views := make([]buffer.View, len(pieces))
- for i, p := range pieces {
- views[i] = []byte(p)
- }
-
- return buffer.NewVectorisedView(size, views)
-}
-
-func pkt(size int, pieces ...string) *stack.PacketBuffer {
- return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv(size, pieces...),
- })
-}
-
-type processInput struct {
- id FragmentID
- first uint16
- last uint16
- more bool
- proto uint8
- pkt *stack.PacketBuffer
-}
-
-type processOutput struct {
- vv buffer.VectorisedView
- proto uint8
- done bool
-}
-
-var processTestCases = []struct {
- comment string
- in []processInput
- out []processOutput
-}{
- {
- comment: "One ID",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
- {
- comment: "Next Header protocol mismatch",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "01", "23"), proto: 6, done: true},
- },
- },
- {
- comment: "Two IDs",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")},
- {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "ab", "cd"), done: true},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
-}
-
-func TestFragmentationProcess(t *testing.T) {
- for _, c := range processTestCases {
- t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
- firstFragmentProto := c.in[0].proto
- for i, in := range c.in {
- resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
- if err != nil {
- t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
- in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
- }
- if done != c.out[i].done {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
- in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
- }
- if c.out[i].done {
- if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s",
- in.id, in.first, in.last, in.more, in.proto, in.pkt, diff)
- }
- if firstFragmentProto != proto {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
- in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
- }
- if _, ok := f.reassemblers[in.id]; ok {
- t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
- }
- for n := f.rList.Front(); n != nil; n = n.Next() {
- if n.id == in.id {
- t.Errorf("Process(%d) did not remove buffer from rList", i)
- }
- }
- }
- }
- })
- }
-}
-
-func TestReassemblingTimeout(t *testing.T) {
- const (
- reassemblyTimeout = time.Millisecond
- protocol = 0xff
- )
-
- type fragment struct {
- first uint16
- last uint16
- more bool
- data string
- }
-
- type event struct {
- // name is a nickname of this event.
- name string
-
- // clockAdvance is a duration to advance the clock. The clock advances
- // before a fragment specified in the fragment field is processed.
- clockAdvance time.Duration
-
- // fragment is a fragment to process. This can be nil if there is no
- // fragment to process.
- fragment *fragment
-
- // expectDone is true if the fragmentation instance should report the
- // reassembly is done after the fragment is processd.
- expectDone bool
-
- // memSizeAfterEvent is the expected memory size of the fragmentation
- // instance after the event.
- memSizeAfterEvent int
- }
-
- memSizeOfFrags := func(frags ...*fragment) int {
- var size int
- for _, frag := range frags {
- size += pkt(len(frag.data), frag.data).MemSize()
- }
- return size
- }
-
- half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
- half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
-
- tests := []struct {
- name string
- events []event
- }{
- {
- name: "half1 and half2 are reassembled successfully",
- events: []event{
- {
- name: "half1",
- fragment: half1,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half2",
- fragment: half2,
- expectDone: true,
- memSizeAfterEvent: 0,
- },
- },
- },
- {
- name: "half1 timeout, half2 timeout",
- events: []event{
- {
- name: "half1",
- fragment: half1,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half1 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half1 reassembly timeout",
- clockAdvance: 1,
- memSizeAfterEvent: 0,
- },
- {
- name: "half2",
- fragment: half2,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half2),
- },
- {
- name: "half2 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- memSizeAfterEvent: memSizeOfFrags(half2),
- },
- {
- name: "half2 reassembly timeout",
- clockAdvance: 1,
- memSizeAfterEvent: 0,
- },
- },
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil)
- for _, event := range test.events {
- clock.Advance(event.clockAdvance)
- if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data))
- if err != nil {
- t.Fatalf("%s: f.Process failed: %s", event.name, err)
- }
- if done != event.expectDone {
- t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
- }
- }
- if got, want := f.memSize, event.memSizeAfterEvent; got != want {
- t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want)
- }
- }
- })
- }
-}
-
-func TestMemoryLimits(t *testing.T) {
- lowLimit := pkt(1, "0").MemSize()
- highLimit := 3 * lowLimit // Allow at most 3 such packets.
- f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
- // Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
- // Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1"))
- // Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2"))
-
- // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
- // evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3"))
-
- if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
- t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
- }
- if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
- t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
- }
- if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
- t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
- }
-}
-
-func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- memSize := pkt(1, "0").MemSize()
- f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
- // Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
- // Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
-
- if got, want := f.memSize, memSize; got != want {
- t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
- }
-}
-
-func TestErrors(t *testing.T) {
- tests := []struct {
- name string
- blockSize uint16
- first uint16
- last uint16
- more bool
- data string
- err error
- }{
- {
- name: "exact block size without more",
- blockSize: 2,
- first: 2,
- last: 3,
- more: false,
- data: "01",
- },
- {
- name: "exact block size with more",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "01",
- },
- {
- name: "exact block size with more and extra data",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "012",
- err: ErrInvalidArgs,
- },
- {
- name: "exact block size with more and too little data",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "0",
- err: ErrInvalidArgs,
- },
- {
- name: "not exact block size with more",
- blockSize: 2,
- first: 2,
- last: 2,
- more: true,
- data: "0",
- err: ErrInvalidArgs,
- },
- {
- name: "not exact block size without more",
- blockSize: 2,
- first: 2,
- last: 2,
- more: false,
- data: "0",
- },
- {
- name: "first not a multiple of block size",
- blockSize: 2,
- first: 3,
- last: 4,
- more: true,
- data: "01",
- err: ErrInvalidArgs,
- },
- {
- name: "first more than last",
- blockSize: 2,
- first: 4,
- last: 3,
- more: true,
- data: "01",
- err: ErrInvalidArgs,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil)
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data))
- if !errors.Is(err, test.err) {
- t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
- }
- if done {
- t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data)
- }
- })
- }
-}
-
-type fragmentInfo struct {
- remaining int
- copied int
- offset int
- more bool
-}
-
-func TestPacketFragmenter(t *testing.T) {
- const (
- reserve = 60
- proto = 0
- )
-
- tests := []struct {
- name string
- fragmentPayloadLen uint32
- transportHeaderLen int
- payloadSize int
- wantFragments []fragmentInfo
- }{
- {
- name: "Packet exactly fits in MTU",
- fragmentPayloadLen: 1280,
- transportHeaderLen: 0,
- payloadSize: 1280,
- wantFragments: []fragmentInfo{
- {remaining: 0, copied: 1280, offset: 0, more: false},
- },
- },
- {
- name: "Packet exactly does not fit in MTU",
- fragmentPayloadLen: 1000,
- transportHeaderLen: 0,
- payloadSize: 1001,
- wantFragments: []fragmentInfo{
- {remaining: 1, copied: 1000, offset: 0, more: true},
- {remaining: 0, copied: 1, offset: 1000, more: false},
- },
- },
- {
- name: "Packet has a transport header",
- fragmentPayloadLen: 560,
- transportHeaderLen: 40,
- payloadSize: 560,
- wantFragments: []fragmentInfo{
- {remaining: 1, copied: 560, offset: 0, more: true},
- {remaining: 0, copied: 40, offset: 560, more: false},
- },
- },
- {
- name: "Packet has a huge transport header",
- fragmentPayloadLen: 500,
- transportHeaderLen: 1300,
- payloadSize: 500,
- wantFragments: []fragmentInfo{
- {remaining: 3, copied: 500, offset: 0, more: true},
- {remaining: 2, copied: 500, offset: 500, more: true},
- {remaining: 1, copied: 500, offset: 1000, more: true},
- {remaining: 0, copied: 300, offset: 1500, more: false},
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
- originalPayload := stack.PayloadSince(pkt.TransportHeader())
- var reassembledPayload buffer.VectorisedView
- pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
- for i := 0; ; i++ {
- fragPkt, offset, copied, more := pf.BuildNextFragment()
- wantFragment := test.wantFragments[i]
- if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
- t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
- }
- if copied != wantFragment.copied {
- t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
- }
- if offset != wantFragment.offset {
- t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
- }
- if more != wantFragment.more {
- t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
- }
- if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
- t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
- }
- if got := fragPkt.AvailableHeaderBytes(); got != reserve {
- t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
- }
- if got := fragPkt.TransportHeader().View().Size(); got != 0 {
- t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
- }
- reassembledPayload.AppendViews(fragPkt.Data().Views())
- if !more {
- if i != len(test.wantFragments)-1 {
- t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
- }
- break
- }
- }
- if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" {
- t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-type testTimeoutHandler struct {
- pkt *stack.PacketBuffer
-}
-
-func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
- h.pkt = pkt
-}
-
-func TestTimeoutHandler(t *testing.T) {
- const (
- proto = 99
- )
-
- pk1 := pkt(1, "1")
- pk2 := pkt(1, "2")
-
- type processParam struct {
- first uint16
- last uint16
- more bool
- pkt *stack.PacketBuffer
- }
-
- tests := []struct {
- name string
- params []processParam
- wantError bool
- wantPkt *stack.PacketBuffer
- }{
- {
- name: "onTimeout runs",
- params: []processParam{
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk1,
- },
- },
- wantError: false,
- wantPkt: pk1,
- },
- {
- name: "no first fragment",
- params: []processParam{
- {
- first: 1,
- last: 1,
- more: true,
- pkt: pk1,
- },
- },
- wantError: false,
- wantPkt: nil,
- },
- {
- name: "second pkt is ignored",
- params: []processParam{
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk1,
- },
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk2,
- },
- },
- wantError: false,
- wantPkt: pk1,
- },
- {
- name: "invalid args - first is greater than last",
- params: []processParam{
- {
- first: 1,
- last: 0,
- more: true,
- pkt: pk1,
- },
- },
- wantError: true,
- wantPkt: nil,
- },
- }
-
- id := FragmentID{ID: 0}
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- handler := &testTimeoutHandler{pkt: nil}
-
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler)
-
- for _, p := range test.params {
- if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError {
- t.Errorf("f.Process error = %s", err)
- }
- }
- if !test.wantError {
- r, ok := f.reassemblers[id]
- if !ok {
- t.Fatal("Reassembler not found")
- }
- f.release(r, true)
- }
- switch {
- case handler.pkt != nil && test.wantPkt == nil:
- t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView())
- case handler.pkt == nil && test.wantPkt != nil:
- t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView())
- case handler.pkt != nil && test.wantPkt != nil:
- if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" {
- t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
- }
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_list.go b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go
new file mode 100644
index 000000000..673bb11b0
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go
@@ -0,0 +1,221 @@
+package fragmentation
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type reassemblerElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type reassemblerList struct {
+ head *reassembler
+ tail *reassembler
+}
+
+// Reset resets list l to the empty state.
+func (l *reassemblerList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *reassemblerList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *reassemblerList) Front() *reassembler {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *reassemblerList) Back() *reassembler {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *reassemblerList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *reassemblerList) PushFront(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *reassemblerList) PushBack(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *reassemblerList) PushBackList(m *reassemblerList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *reassemblerList) InsertAfter(b, e *reassembler) {
+ bLinker := reassemblerElementMapper{}.linkerFor(b)
+ eLinker := reassemblerElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *reassemblerList) InsertBefore(a, e *reassembler) {
+ aLinker := reassemblerElementMapper{}.linkerFor(a)
+ eLinker := reassemblerElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *reassemblerList) Remove(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type reassemblerEntry struct {
+ next *reassembler
+ prev *reassembler
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) Next() *reassembler {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) Prev() *reassembler {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) SetNext(elem *reassembler) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) SetPrev(elem *reassembler) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
deleted file mode 100644
index cfd9f00ef..000000000
--- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
+++ /dev/null
@@ -1,233 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "bytes"
- "math"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type processParams struct {
- first uint16
- last uint16
- more bool
- pkt *stack.PacketBuffer
- wantDone bool
- wantError error
-}
-
-func TestReassemblerProcess(t *testing.T) {
- const proto = 99
-
- v := func(size int) buffer.View {
- payload := buffer.NewView(size)
- for i := 1; i < size; i++ {
- payload[i] = uint8(i) * 3
- }
- return payload
- }
-
- pkt := func(sizes ...int) *stack.PacketBuffer {
- var vv buffer.VectorisedView
- for _, size := range sizes {
- vv.AppendView(v(size))
- }
- return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- }
-
- var tests = []struct {
- name string
- params []processParams
- want []hole
- wantPkt *stack.PacketBuffer
- }{
- {
- name: "No fragments",
- params: nil,
- want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
- },
- {
- name: "One fragment at beginning",
- params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)},
- {first: 2, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "One fragment in the middle",
- params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)},
- {first: 0, last: 0, filled: false, final: false},
- {first: 3, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "One fragment at the end",
- params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)},
- {first: 0, last: 0, filled: false},
- },
- },
- {
- name: "One fragment completing a packet",
- params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
- want: []hole{
- {first: 0, last: 1, filled: true, final: true},
- },
- wantPkt: pkt(2),
- },
- {
- name: "Two fragments completing a packet",
- params: []processParams{
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 1, filled: true, final: false},
- {first: 2, last: 3, filled: true, final: true},
- },
- wantPkt: pkt(2, 2),
- },
- {
- name: "Two fragments completing a packet with a duplicate",
- params: []processParams{
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 1, filled: true, final: false},
- {first: 2, last: 3, filled: true, final: true},
- },
- wantPkt: pkt(2, 2),
- },
- {
- name: "Two fragments completing a packet with a partial duplicate",
- params: []processParams{
- {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
- {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 3, filled: true, final: false},
- {first: 4, last: 5, filled: true, final: true},
- },
- wantPkt: pkt(4, 2),
- },
- {
- name: "Two overlapping fragments",
- params: []processParams{
- {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
- {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
- },
- want: []hole{
- {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)},
- {first: 11, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "Two final fragments with different ends",
- params: []processParams{
- {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
- {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
- },
- want: []hole{
- {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)},
- {first: 0, last: 9, filled: false, final: false},
- },
- },
- {
- name: "Two final fragments - duplicate",
- params: []processParams{
- {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
- {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
- },
- want: []hole{
- {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
- {first: 0, last: 4, filled: false, final: false},
- },
- },
- {
- name: "Two final fragments - duplicate, with different ends",
- params: []processParams{
- {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
- {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
- },
- want: []hole{
- {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
- {first: 0, last: 4, filled: false, final: false},
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- var resPkt *stack.PacketBuffer
- var isDone bool
- for _, param := range test.params {
- pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
- if done != param.wantDone || err != param.wantError {
- t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
- }
- if done {
- resPkt = pkt
- isDone = true
- }
- }
-
- ignorePkt := func(a, b *stack.PacketBuffer) bool { return true }
- cmpPktData := func(a, b *stack.PacketBuffer) bool {
- if a == nil || b == nil {
- return a == b
- }
- return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView())
- }
-
- if isDone {
- if diff := cmp.Diff(
- test.want, r.holes,
- cmp.AllowUnexported(hole{}),
- // Do not compare pkt in hole. Data will be altered.
- cmp.Comparer(ignorePkt),
- ); diff != "" {
- t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" {
- t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff)
- }
- } else {
- if diff := cmp.Diff(
- test.want, r.holes,
- cmp.AllowUnexported(hole{}),
- cmp.Comparer(cmpPktData),
- ); diff != "" {
- t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
- }
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD
deleted file mode 100644
index fd944ce99..000000000
--- a/pkg/tcpip/network/internal/ip/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ip",
- srcs = [
- "duplicate_address_detection.go",
- "errors.go",
- "generic_multicast_protocol.go",
- "stats.go",
- ],
- visibility = [
- "//pkg/tcpip/network/arp:__pkg__",
- "//pkg/tcpip/network/ipv4:__pkg__",
- "//pkg/tcpip/network/ipv6:__pkg__",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ip_x_test",
- size = "small",
- srcs = [
- "duplicate_address_detection_test.go",
- "generic_multicast_protocol_test.go",
- ],
- deps = [
- ":ip",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
deleted file mode 100644
index a22b712c6..000000000
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ /dev/null
@@ -1,381 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "bytes"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type mockDADProtocol struct {
- t *testing.T
-
- mu struct {
- sync.Mutex
-
- dad ip.DAD
- sentNonces map[tcpip.Address][][]byte
- }
-}
-
-func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- m.t = t
- opts.Protocol = m
- m.mu.dad.Init(&m.mu, c, opts)
- m.initLocked()
-}
-
-func (m *mockDADProtocol) initLocked() {
- m.mu.sentNonces = make(map[tcpip.Address][][]byte)
-}
-
-func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce)
- return nil
-}
-
-func (m *mockDADProtocol) check(addrs []tcpip.Address) string {
- sentNonces := make(map[tcpip.Address][][]byte)
- for _, a := range addrs {
- sentNonces[a] = append(sentNonces[a], nil)
- }
-
- return m.checkWithNonce(sentNonces)
-}
-
-func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces)
- m.initLocked()
- return diff
-}
-
-func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.mu.dad.CheckDuplicateAddressLocked(addr, h)
-}
-
-func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.dad.StopLocked(addr, reason)
-}
-
-func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce)
-}
-
-func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.dad.SetConfigsLocked(c)
-}
-
-const (
- addr1 = tcpip.Address("\x01")
- addr2 = tcpip.Address("\x02")
- addr3 = tcpip.Address("\x03")
- addr4 = tcpip.Address("\x04")
-)
-
-type dadResult struct {
- Addr tcpip.Address
- R stack.DADResult
-}
-
-func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) {
- return func(r stack.DADResult) {
- ch <- dadResult{Addr: a, R: r}
- }
-}
-
-func TestDADCheckDuplicateAddress(t *testing.T) {
- var dad mockDADProtocol
- clock := faketime.NewManualClock()
- dad.init(t, stack.DADConfigurations{}, ip.DADOptions{
- Clock: clock,
- })
-
- ch := make(chan dadResult, 2)
-
- // DAD should initially be disabled.
- if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled)
- }
- // Wait for any initially fired timers to complete.
- clock.Advance(0)
- if diff := dad.check(nil); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- // Enable and request DAD.
- dadConfigs1 := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
- dad.setConfigs(dadConfigs1)
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
- clock.Advance(0)
- if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
- // The second request for DAD on the same address should use the original
- // request since it has not completed yet.
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning)
- }
- clock.Advance(0)
- if diff := dad.check(nil); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- dadConfigs2 := stack.DADConfigurations{
- DupAddrDetectTransmits: 2,
- RetransmitTimer: time.Second,
- }
- dad.setConfigs(dadConfigs2)
- // A new address should start a new DAD process.
- if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- clock.Advance(0)
- if diff := dad.check([]tcpip.Address{addr2}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- // Make sure DAD for addr1 only resolves after the expected timeout.
- const delta = time.Nanosecond
- dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer
- clock.Advance(dadConfig1Duration - delta)
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r)
- default:
- }
- clock.Advance(delta)
- for i := 0; i < 2; i++ {
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- // Make sure DAD for addr2 only resolves after the expected timeout.
- dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer
- clock.Advance(dadConfig2Duration - dadConfig1Duration - delta)
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r)
- default:
- }
- clock.Advance(delta)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should be able to restart DAD for addr2 after it resolved.
- if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- clock.Advance(0)
- if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
- clock.Advance(dadConfig2Duration)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anymore results.
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
-}
-
-func TestDADStop(t *testing.T) {
- var dad mockDADProtocol
- clock := faketime.NewManualClock()
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
- dad.init(t, dadConfigs, ip.DADOptions{
- Clock: clock,
- })
-
- ch := make(chan dadResult, 1)
-
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
- if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- clock.Advance(0)
- if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- dad.stop(addr1, &stack.DADAborted{})
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- dad.stop(addr2, &stack.DADDupAddrDetected{})
- if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
- clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should be able to restart DAD for an address we stopped DAD on.
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
- clock.Advance(0)
- if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
- clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anymore updates.
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
-}
-
-func TestNonce(t *testing.T) {
- const (
- nonceSize = 2
-
- extendRequestAttempts = 2
-
- dupAddrDetectTransmits = 2
- extendTransmits = 5
- )
-
- var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte
- for i := range secureRNGBytes {
- secureRNGBytes[i] = byte(i)
- }
-
- tests := []struct {
- name string
- mockedReceivedNonce []byte
- expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition
- expectedTransmits int
- }{
- {
- name: "not matching",
- mockedReceivedNonce: []byte{0, 0},
- expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual},
- expectedTransmits: dupAddrDetectTransmits,
- },
- {
- name: "matching nonce",
- mockedReceivedNonce: secureRNGBytes[:nonceSize],
- expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended},
- expectedTransmits: dupAddrDetectTransmits + extendTransmits,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var dad mockDADProtocol
- clock := faketime.NewManualClock()
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: dupAddrDetectTransmits,
- RetransmitTimer: time.Second,
- }
-
- var secureRNG bytes.Reader
- secureRNG.Reset(secureRNGBytes[:])
- dad.init(t, dadConfigs, ip.DADOptions{
- Clock: clock,
- SecureRNG: &secureRNG,
- NonceSize: nonceSize,
- ExtendDADTransmits: extendTransmits,
- })
-
- ch := make(chan dadResult, 1)
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
-
- clock.Advance(0)
- for i, want := range test.expectedResults {
- if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want {
- t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want)
- }
- }
-
- for i := 0; i < test.expectedTransmits; i++ {
- if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{
- addr1: {
- secureRNGBytes[nonceSize*i:][:nonceSize],
- },
- }); diff != "" {
- t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff)
- }
-
- clock.Advance(dadConfigs.RetransmitTimer)
- }
-
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anymore updates.
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
deleted file mode 100644
index 0b51563cd..000000000
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ /dev/null
@@ -1,808 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
-)
-
-const maxUnsolicitedReportDelay = time.Second
-
-var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
-
-type mockMulticastGroupProtocolProtectedFields struct {
- sync.RWMutex
-
- genericMulticastGroup ip.GenericMulticastProtocolState
- sendReportGroupAddrCount map[tcpip.Address]int
- sendLeaveGroupAddrCount map[tcpip.Address]int
- makeQueuePackets bool
- disabled bool
-}
-
-type mockMulticastGroupProtocol struct {
- t *testing.T
-
- skipProtocolAddress tcpip.Address
-
- mu mockMulticastGroupProtocolProtectedFields
-}
-
-func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.initLocked()
- opts.Protocol = m
- m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
-}
-
-func (m *mockMulticastGroupProtocol) initLocked() {
- m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
- m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
-}
-
-func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.disabled = !v
-}
-
-func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.makeQueuePackets = v
-}
-
-func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.JoinGroupLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.HandleReportLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
-}
-
-func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
- m.mu.RLock()
- defer m.mu.RUnlock()
- return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) makeAllNonMember() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
-}
-
-func (m *mockMulticastGroupProtocol) initializeGroups() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.InitializeGroupsLocked()
-}
-
-func (m *mockMulticastGroupProtocol) sendQueuedReports() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.SendQueuedReportsLocked()
-}
-
-// Enabled implements ip.MulticastGroupProtocol.
-//
-// Precondition: m.mu must be read locked.
-func (m *mockMulticastGroupProtocol) Enabled() bool {
- if m.mu.TryLock() {
- m.mu.Unlock()
- m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
- }
-
- return !m.mu.disabled
-}
-
-// SendReport implements ip.MulticastGroupProtocol.
-//
-// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
- if m.mu.TryLock() {
- m.mu.Unlock()
- m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
- }
- if m.mu.TryRLock() {
- m.mu.RUnlock()
- m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
- }
-
- m.mu.sendReportGroupAddrCount[groupAddress]++
- return !m.mu.makeQueuePackets, nil
-}
-
-// SendLeave implements ip.MulticastGroupProtocol.
-//
-// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
- if m.mu.TryLock() {
- m.mu.Unlock()
- m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
- }
- if m.mu.TryRLock() {
- m.mu.RUnlock()
- m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
- }
-
- m.mu.sendLeaveGroupAddrCount[groupAddress]++
- return nil
-}
-
-// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
-func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
- return groupAddress != m.skipProtocolAddress
-}
-
-func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- sendReportGroupAddrCount := make(map[tcpip.Address]int)
- for _, a := range sendReportGroupAddresses {
- sendReportGroupAddrCount[a] = 1
- }
-
- sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
- for _, a := range sendLeaveGroupAddresses {
- sendLeaveGroupAddrCount[a] = 1
- }
-
- diff := cmp.Diff(
- &mockMulticastGroupProtocol{
- mu: mockMulticastGroupProtocolProtectedFields{
- sendReportGroupAddrCount: sendReportGroupAddrCount,
- sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
- },
- },
- m,
- cmp.AllowUnexported(mockMulticastGroupProtocol{}),
- cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
- // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
- cmp.FilterPath(
- func(p cmp.Path) bool {
- switch p.Last().String() {
- case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress":
- return true
- default:
- return false
- }
- },
- cmp.Ignore(),
- ),
- )
- m.initLocked()
- return diff
-}
-
-func TestJoinGroup(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- shouldSendReports bool
- }{
- {
- name: "Normal group",
- addr: addr1,
- shouldSendReports: true,
- },
- {
- name: "All-nodes group",
- addr: addr2,
- shouldSendReports: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(0)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- // Joining a group should send a report immediately and another after
- // a random interval between 0 and the maximum unsolicited report delay.
- mgp.joinGroup(test.addr)
- if test.shouldSendReports {
- if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestLeaveGroup(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- shouldSendMessages bool
- }{
- {
- name: "Normal group",
- addr: addr1,
- shouldSendMessages: true,
- },
- {
- name: "All-nodes group",
- addr: addr2,
- shouldSendMessages: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(1)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(test.addr)
- if test.shouldSendMessages {
- if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Leaving a group should send a leave report immediately and cancel any
- // delayed reports.
- {
-
- if !mgp.leaveGroup(test.addr) {
- t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
- }
- }
- if test.shouldSendMessages {
- if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Should have no more messages to send.
- //
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestHandleReport(t *testing.T) {
- tests := []struct {
- name string
- reportAddr tcpip.Address
- expectReportsFor []tcpip.Address
- }{
- {
- name: "Unpecified empty",
- reportAddr: "",
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Unpecified any",
- reportAddr: "\x00",
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Specified",
- reportAddr: addr1,
- expectReportsFor: []tcpip.Address{addr2},
- },
- {
- name: "Specified all-nodes",
- reportAddr: addr3,
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Specified other",
- reportAddr: addr4,
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(2)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr3)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a report for a group we have a timer scheduled for should
- // cancel our delayed report timer for the group.
- mgp.handleReport(test.reportAddr)
- if len(test.expectReportsFor) != 0 {
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestHandleQuery(t *testing.T) {
- tests := []struct {
- name string
- queryAddr tcpip.Address
- maxDelay time.Duration
- expectQueriedReportsFor []tcpip.Address
- expectDelayedReportsFor []tcpip.Address
- }{
- {
- name: "Unpecified empty",
- queryAddr: "",
- maxDelay: 0,
- expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
- expectDelayedReportsFor: nil,
- },
- {
- name: "Unpecified any",
- queryAddr: "\x00",
- maxDelay: 1,
- expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
- expectDelayedReportsFor: nil,
- },
- {
- name: "Specified",
- queryAddr: addr1,
- maxDelay: 2,
- expectQueriedReportsFor: []tcpip.Address{addr1},
- expectDelayedReportsFor: []tcpip.Address{addr2},
- },
- {
- name: "Specified all-nodes",
- queryAddr: addr3,
- maxDelay: 3,
- expectQueriedReportsFor: nil,
- expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Specified other",
- queryAddr: addr4,
- maxDelay: 4,
- expectQueriedReportsFor: nil,
- expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr3)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a query should make us reschedule our delayed report timer
- // to some time within the new max response delay.
- mgp.handleQuery(test.queryAddr, test.maxDelay)
- clock.Advance(test.maxDelay)
- if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // The groups that were not affected by the query should still send a
- // report after the max unsolicited report delay.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestJoinCount(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(4)),
- Clock: clock,
- MaxUnsolicitedReportDelay: time.Second,
- })
-
- // Set the join count to 2 for a group.
- mgp.joinGroup(addr1)
- if !mgp.isLocallyJoined(addr1) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- // Only the first join should trigger a report to be sent.
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr1)
- if !mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Group should still be considered joined after leaving once.
- if !mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
- }
- if !mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- // A leave report should only be sent once the join count reaches 0.
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leaving once more should actually remove us from the group.
- if !mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
- }
- if mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Group should no longer be joined so we should not have anything to
- // leave.
- if mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
- }
- if mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should have no more messages to send.
- //
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
-
-func TestMakeAllNonMemberAndInitialize(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr3)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should send the leave reports for each but still consider them locally
- // joined.
- mgp.makeAllNonMember()
- if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- for _, group := range []tcpip.Address{addr1, addr2, addr3} {
- if !mgp.isLocallyJoined(group) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
- }
- }
-
- // Should send the initial set of unsolcited reports.
- mgp.initializeGroups()
- if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
-
-// TestGroupStateNonMember tests that groups do not send packets when in the
-// non-member state, but are still considered locally joined.
-func TestGroupStateNonMember(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
- mgp.setEnabled(false)
-
- // Joining groups should not send any reports.
- mgp.joinGroup(addr1)
- if !mgp.isLocallyJoined(addr1) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if !mgp.isLocallyJoined(addr1) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a query should not send any reports.
- mgp.handleQuery(addr1, time.Nanosecond)
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Nanosecond)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Leaving groups should not send any leave messages.
- if !mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
- }
- if mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
-
-func TestQueuedPackets(t *testing.T) {
- clock := faketime.NewManualClock()
- mgp := mockMulticastGroupProtocol{t: t}
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(4)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- // Joining should trigger a SendReport, but mgp should report that we did not
- // send the packet.
- mgp.setQueuePackets(true)
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // The delayed report timer should have been cancelled since we did not send
- // the initial report earlier.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Mock being able to successfully send the report.
- mgp.setQueuePackets(false)
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // The delayed report (sent after the initial report) should now be sent.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anything else to send (we should be idle).
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receive a query but mock being unable to send reports again.
- mgp.setQueuePackets(true)
- mgp.handleQuery(addr1, time.Nanosecond)
- clock.Advance(time.Nanosecond)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Mock being able to send reports again - we should have a packet queued to
- // send.
- mgp.setQueuePackets(false)
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anything else to send.
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receive a query again, but mock being unable to send reports.
- mgp.setQueuePackets(true)
- mgp.handleQuery(addr1, time.Nanosecond)
- clock.Advance(time.Nanosecond)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a report should should transition us into the idle member state,
- // even if we had a packet queued. We should no longer have any packets to
- // send.
- mgp.handleReport(addr1)
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // When we fail to send the initial set of reports, incoming reports should
- // not affect a newly joined group's reports from being sent.
- mgp.setQueuePackets(true)
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.handleReport(addr2)
- // Attempting to send queued reports while still unable to send reports should
- // not change the host state.
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- // Mock being able to successfully send the report.
- mgp.setQueuePackets(false)
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- // The delayed report (sent after the initial report) should now be sent.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anything else to send.
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
diff --git a/pkg/tcpip/network/internal/ip/ip_state_autogen.go b/pkg/tcpip/network/internal/ip/ip_state_autogen.go
new file mode 100644
index 000000000..aee77044e
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/ip_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package ip
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
deleted file mode 100644
index cec3e62c4..000000000
--- a/pkg/tcpip/network/internal/testutil/BUILD
+++ /dev/null
@@ -1,20 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "testutil",
- srcs = ["testutil.go"],
- visibility = [
- "//pkg/tcpip/network/arp:__pkg__",
- "//pkg/tcpip/network/internal/fragmentation:__pkg__",
- "//pkg/tcpip/network/ipv4:__pkg__",
- "//pkg/tcpip/network/ipv6:__pkg__",
- ],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
deleted file mode 100644
index 605e9ef8d..000000000
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ /dev/null
@@ -1,129 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package testutil defines types and functions used to test Network Layer
-// functionality such as IP fragmentation.
-package testutil
-
-import (
- "fmt"
- "math/rand"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-// MockLinkEndpoint is an endpoint used for testing, it stores packets written
-// to it and can mock errors.
-type MockLinkEndpoint struct {
- // WrittenPackets is where packets written to the endpoint are stored.
- WrittenPackets []*stack.PacketBuffer
-
- mtu uint32
- err tcpip.Error
- allowPackets int
-}
-
-// NewMockLinkEndpoint creates a new MockLinkEndpoint.
-//
-// err is the error that will be returned once allowPackets packets are written
-// to the endpoint.
-func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint {
- return &MockLinkEndpoint{
- mtu: mtu,
- err: err,
- allowPackets: allowPackets,
- }
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
- return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- var n int
-
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := ep.WritePacket(r, protocol, pkt); err != nil {
- return n, err
- }
- n++
- }
-
- return n, nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*MockLinkEndpoint) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*MockLinkEndpoint) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
-// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
-// how many random bytes will be copied in the Transport Header.
-// extraHeaderReserveLength indicates how much extra space will be reserved for
-// the other headers. The payload is made from Views of the sizes listed in
-// viewSizes.
-func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
- var views buffer.VectorisedView
-
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- if _, err := rand.Read(newView); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- views.AppendView(newView)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
- Data: views,
- })
- pkt.NetworkProtocolNumber = proto
- if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- return pkt
-}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
deleted file mode 100644
index 74aad126c..000000000
--- a/pkg/tcpip/network/ip_test.go
+++ /dev/null
@@ -1,2020 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "fmt"
- "strings"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
-)
-
-const nicID = 1
-
-var (
- localIPv4Addr = testutil.MustParse4("10.0.0.1")
- remoteIPv4Addr = testutil.MustParse4("10.0.0.2")
- ipv4SubnetAddr = testutil.MustParse4("10.0.0.0")
- ipv4SubnetMask = testutil.MustParse4("255.255.255.0")
- ipv4Gateway = testutil.MustParse4("10.0.0.3")
- localIPv6Addr = testutil.MustParse6("a00::1")
- remoteIPv6Addr = testutil.MustParse6("a00::2")
- ipv6SubnetAddr = testutil.MustParse6("a00::")
- ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00")
- ipv6Gateway = testutil.MustParse6("a00::3")
-)
-
-var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
- Address: localIPv4Addr,
- PrefixLen: 24,
-}
-
-var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
- Address: localIPv6Addr,
- PrefixLen: 120,
-}
-
-type transportError struct {
- origin tcpip.SockErrOrigin
- typ uint8
- code uint8
- info uint32
- kind stack.TransportErrorKind
-}
-
-// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
-// The former is used to pretend that it's a link endpoint so that we can
-// inspect packets written by the network endpoints. The latter is used to
-// pretend that it's the network stack so that it can inspect incoming packets
-// that have been handled by the network endpoints.
-//
-// Packets are checked by comparing their fields/values against the expected
-// values stored in the test object itself.
-type testObject struct {
- t *testing.T
- protocol tcpip.TransportProtocolNumber
- contents []byte
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- v4 bool
- transErr transportError
-
- dataCalls int
- controlCalls int
-}
-
-// checkValues verifies that the transport protocol, data contents, src & dst
-// addresses of a packet match what's expected. If any field doesn't match, the
-// test fails.
-func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) {
- if protocol != t.protocol {
- t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
- }
-
- if srcAddr != t.srcAddr {
- t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
- }
-
- if dstAddr != t.dstAddr {
- t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
- }
-
- if len(v) != len(t.contents) {
- t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
- }
-
- for i := range t.contents {
- if t.contents[i] != v[i] {
- t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
- }
- }
-}
-
-// DeliverTransportPacket is called by network endpoints after parsing incoming
-// packets. This is used by the test object to verify that the results of the
-// parsing are expected.
-func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
- netHdr := pkt.Network()
- t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress())
- t.dataCalls++
- return stack.TransportPacketHandled
-}
-
-// DeliverTransportError is called by network endpoints after parsing
-// incoming control (ICMP) packets. This is used by the test object to verify
-// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) {
- t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local)
- if diff := cmp.Diff(
- t.transErr,
- transportError{
- origin: transErr.Origin(),
- typ: transErr.Type(),
- code: transErr.Code(),
- info: transErr.Info(),
- kind: transErr.Kind(),
- },
- cmp.AllowUnexported(transportError{}),
- ); diff != "" {
- t.t.Errorf("transport error mismatch (-want +got):\n%s", diff)
- }
- t.controlCalls++
-}
-
-// Attach is only implemented to satisfy the LinkEndpoint interface.
-func (*testObject) Attach(stack.NetworkDispatcher) {}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (*testObject) IsAttached() bool {
- return true
-}
-
-// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
-// matches the linux loopback MTU.
-func (*testObject) MTU() uint32 {
- return 65536
-}
-
-// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
- return 0
-}
-
-// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
-func (*testObject) MaxHeaderLength() uint16 {
- return 0
-}
-
-// LinkAddress returns the link address of this endpoint.
-func (*testObject) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-// Wait implements stack.LinkEndpoint.Wait.
-func (*testObject) Wait() {}
-
-// WritePacket is called by network endpoints after producing a packet and
-// writing it to the link endpoint. This is used by the test object to verify
-// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- var prot tcpip.TransportProtocolNumber
- var srcAddr tcpip.Address
- var dstAddr tcpip.Address
-
- if t.v4 {
- h := header.IPv4(pkt.NetworkHeader().View())
- prot = tcpip.TransportProtocolNumber(h.Protocol())
- srcAddr = h.SourceAddress()
- dstAddr = h.DestinationAddress()
-
- } else {
- h := header.IPv6(pkt.NetworkHeader().View())
- prot = tcpip.TransportProtocolNumber(h.NextHeader())
- srcAddr = h.SourceAddress()
- dstAddr = h.DestinationAddress()
- }
- t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr)
- return nil
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
-}
-
-// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
-func (*testObject) ARPHardwareType() header.ARPHardwareType {
- panic("not implemented")
-}
-
-// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- panic("not implemented")
-}
-
-func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv4.ProtocolNumber, local)
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- Gateway: ipv4Gateway,
- NIC: 1,
- }})
-
- return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
-}
-
-func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv6.ProtocolNumber, local)
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: ipv6Gateway,
- NIC: 1,
- }})
-
- return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
-}
-
-func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) {
- t.Helper()
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- e := channel.New(1, mtu, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
- }
-
- v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
- }
-
- return s, e
-}
-
-func buildDummyStack(t *testing.T) *stack.Stack {
- t.Helper()
-
- s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
- return s
-}
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- testObject
-
- mu struct {
- sync.RWMutex
- disabled bool
- }
-}
-
-func (*testInterface) ID() tcpip.NICID {
- return nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (t *testInterface) Enabled() bool {
- t.mu.RLock()
- defer t.mu.RUnlock()
- return !t.mu.disabled
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (*testInterface) Spoofing() bool {
- return false
-}
-
-func (t *testInterface) setEnabled(v bool) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.mu.disabled = !v
-}
-
-func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
- return nil
-}
-
-func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
- return nil
-}
-
-func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
- return tcpip.AddressWithPrefix{}, nil
-}
-
-func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
- return false
-}
-
-func TestSourceAddressValidation(t *testing.T) {
- rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: src,
- Dst: localIPv6Addr,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
- })
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- tests := []struct {
- name string
- srcAddress tcpip.Address
- rxICMP func(*channel.Endpoint, tcpip.Address)
- valid bool
- }{
- {
- name: "IPv4 valid",
- srcAddress: "\x01\x02\x03\x04",
- rxICMP: rxIPv4ICMP,
- valid: true,
- },
- {
- name: "IPv6 valid",
- srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
- rxICMP: rxIPv6ICMP,
- valid: true,
- },
- {
- name: "IPv4 unspecified",
- srcAddress: header.IPv4Any,
- rxICMP: rxIPv4ICMP,
- valid: true,
- },
- {
- name: "IPv6 unspecified",
- srcAddress: header.IPv4Any,
- rxICMP: rxIPv6ICMP,
- valid: true,
- },
- {
- name: "IPv4 multicast",
- srcAddress: "\xe0\x00\x00\x01",
- rxICMP: rxIPv4ICMP,
- valid: false,
- },
- {
- name: "IPv6 multicast",
- srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- rxICMP: rxIPv6ICMP,
- valid: false,
- },
- {
- name: "IPv4 broadcast",
- srcAddress: header.IPv4Broadcast,
- rxICMP: rxIPv4ICMP,
- valid: false,
- },
- {
- name: "IPv4 subnet broadcast",
- srcAddress: func() tcpip.Address {
- subnet := localIPv4AddrWithPrefix.Subnet()
- return subnet.Broadcast()
- }(),
- rxICMP: rxIPv4ICMP,
- valid: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
- test.rxICMP(e, test.srcAddress)
-
- var wantValid uint64
- if test.valid {
- wantValid = 1
- }
-
- if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
- t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
- }
- if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
- t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
- }
- })
- }
-}
-
-func TestEnableWhenNICDisabled(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.NetworkProtocolFactory
- protoNum tcpip.NetworkProtocolNumber
- }{
- {
- name: "IPv4",
- protocolFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- },
- {
- name: "IPv6",
- protocolFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var nic testInterface
- nic.setEnabled(false)
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
- })
- p := s.NetworkProtocolInstance(test.protoNum)
-
- // We pass nil for all parameters except the NetworkInterface and Stack
- // since Enable only depends on these.
- ep := p.NewEndpoint(&nic, nil)
-
- // The endpoint should initially be disabled, regardless the NIC's enabled
- // status.
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
- nic.setEnabled(true)
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
-
- // Attempting to enable the endpoint while the NIC is disabled should
- // fail.
- nic.setEnabled(false)
- err := ep.Enable()
- if _, ok := err.(*tcpip.ErrNotPermitted); !ok {
- t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{})
- }
- // ep should consider the NIC's enabled status when determining its own
- // enabled status so we "enable" the NIC to read just the endpoint's
- // enabled status.
- nic.setEnabled(true)
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
-
- // Enabling the interface after the NIC has been enabled should succeed.
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
- if !ep.Enabled() {
- t.Fatal("got ep.Enabled() = false, want = true")
- }
-
- // ep should consider the NIC's enabled status when determining its own
- // enabled status.
- nic.setEnabled(false)
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
-
- // Disabling the endpoint when the NIC is enabled should make the endpoint
- // disabled.
- nic.setEnabled(true)
- ep.Disable()
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
- })
- }
-}
-
-func TestIPv4Send(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
- },
- }
- ep := proto.NewEndpoint(&nic, nil)
- defer ep.Close()
-
- // Allocate and initialize the payload view.
- payload := buffer.NewView(100)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
-
- // Setup the packet buffer.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(ep.MaxHeaderLength()),
- Data: payload.ToVectorisedView(),
- })
-
- // Issue the write.
- nic.testObject.protocol = 123
- nic.testObject.srcAddr = localIPv4Addr
- nic.testObject.dstAddr = remoteIPv4Addr
- nic.testObject.contents = payload
-
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- if err := ep.WritePacket(r, stack.NetworkHeaderParams{
- Protocol: 123,
- TTL: 123,
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-}
-
-func TestReceive(t *testing.T) {
- tests := []struct {
- name string
- protoFactory stack.NetworkProtocolFactory
- protoNum tcpip.NetworkProtocolNumber
- v4 bool
- epAddr tcpip.AddressWithPrefix
- handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
- }{
- {
- name: "IPv4",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- v4: true,
- epAddr: localIPv4Addr.WithPrefix(),
- handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
- const totalLen = header.IPv4MinimumSize + 30 /* payload length */
-
- view := buffer.NewView(totalLen)
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
- TTL: ipv4.DefaultTTL,
- Protocol: 10,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- ep.HandlePacket(pkt)
- },
- },
- {
- name: "IPv6",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- v4: false,
- epAddr: localIPv6Addr.WithPrefix(),
- handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
- const payloadLen = 30
- view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadLen,
- TransportProtocol: 10,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv6MinimumSize; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- ep.HandlePacket(pkt)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
- })
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: test.v4,
- },
- }
- ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
- }
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
- } else {
- ep.DecRef()
- }
-
- stat := s.Stats().IP.PacketsReceived
- if got := stat.Value(); got != 0 {
- t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
- }
- test.handlePacket(t, ep, &nic)
- if nic.testObject.dataCalls != 1 {
- t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
- }
- if got := stat.Value(); got != 1 {
- t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestIPv4ReceiveControl(t *testing.T) {
- const (
- mtu = 0xbeef - header.IPv4MinimumSize
- dataLen = 8
- )
-
- cases := []struct {
- name string
- expectedCount int
- fragmentOffset uint16
- code header.ICMPv4Code
- transErr transportError
- trunc int
- }{
- {
- name: "FragmentationNeeded",
- expectedCount: 1,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP,
- typ: uint8(header.ICMPv4DstUnreachable),
- code: uint8(header.ICMPv4FragmentationNeeded),
- info: mtu,
- kind: stack.PacketTooBigTransportError,
- },
- trunc: 0,
- },
- {
- name: "Truncated (missing IPv4 header)",
- expectedCount: 0,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize,
- },
- {
- name: "Truncated (partial offending packet's IP header)",
- expectedCount: 0,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1,
- },
- {
- name: "Truncated (partial offending packet's data)",
- expectedCount: 0,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1,
- },
- {
- name: "Port unreachable",
- expectedCount: 1,
- fragmentOffset: 0,
- code: header.ICMPv4PortUnreachable,
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP,
- typ: uint8(header.ICMPv4DstUnreachable),
- code: uint8(header.ICMPv4PortUnreachable),
- kind: stack.DestinationPortUnreachableTransportError,
- },
- trunc: 0,
- },
- {
- name: "Non-zero fragment offset",
- expectedCount: 0,
- fragmentOffset: 100,
- code: header.ICMPv4PortUnreachable,
- trunc: 0,
- },
- {
- name: "Zero-length packet",
- expectedCount: 0,
- fragmentOffset: 100,
- code: header.ICMPv4PortUnreachable,
- trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen,
- },
- }
- for _, c := range cases {
- t.Run(c.name, func(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
- view := buffer.NewView(dataOffset + dataLen)
-
- // Create the outer IPv4 header.
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(view) - c.trunc),
- TTL: 20,
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Create the ICMP header.
- icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
- icmp.SetType(header.ICMPv4DstUnreachable)
- icmp.SetCode(c.code)
- icmp.SetIdent(0xdead)
- icmp.SetSequence(0xbeef)
-
- // Create the inner IPv4 header.
- ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
- ip.Encode(&header.IPv4Fields{
- TotalLength: 100,
- TTL: 20,
- Protocol: 10,
- FragmentOffset: c.fragmentOffset,
- SrcAddr: localIPv4Addr,
- DstAddr: remoteIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := dataOffset; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- icmp.SetChecksum(0)
- checksum := ^header.Checksum(icmp, 0 /* initial */)
- icmp.SetChecksum(checksum)
-
- // Give packet to IPv4 endpoint, dispatcher will validate that
- // it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[dataOffset:]
- nic.testObject.transErr = c.transErr
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
- }
- addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
- ep.HandlePacket(pkt)
- if want := c.expectedCount; nic.testObject.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
- }
- })
- }
-}
-
-func TestIPv4FragmentationReceive(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
- },
- }
- ep := proto.NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- totalLen := header.IPv4MinimumSize + 24
-
- frag1 := buffer.NewView(totalLen)
- ip1 := header.IPv4(frag1)
- ip1.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- FragmentOffset: 0,
- Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip1.SetChecksum(^ip1.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- frag1[i] = uint8(i)
- }
-
- frag2 := buffer.NewView(totalLen)
- ip2 := header.IPv4(frag2)
- ip2.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- FragmentOffset: 24,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip2.SetChecksum(^ip2.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- frag2[i] = uint8(i)
- }
-
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
-
- // Send first segment.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: frag1.ToVectorisedView(),
- })
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
- }
- addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
- }
-
- // Send second segment.
- pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: frag2.ToVectorisedView(),
- })
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
- }
-}
-
-func TestIPv6Send(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, nil)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- // Allocate and initialize the payload view.
- payload := buffer.NewView(100)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
-
- // Setup the packet buffer.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(ep.MaxHeaderLength()),
- Data: payload.ToVectorisedView(),
- })
-
- // Issue the write.
- nic.testObject.protocol = 123
- nic.testObject.srcAddr = localIPv6Addr
- nic.testObject.dstAddr = remoteIPv6Addr
- nic.testObject.contents = payload
-
- r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- if err := ep.WritePacket(r, stack.NetworkHeaderParams{
- Protocol: 123,
- TTL: 123,
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-}
-
-func TestIPv6ReceiveControl(t *testing.T) {
- const (
- mtu = 0xffff
- outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
- dataLen = 8
- )
-
- newUint16 := func(v uint16) *uint16 { return &v }
-
- portUnreachableTransErr := transportError{
- origin: tcpip.SockExtErrorOriginICMP6,
- typ: uint8(header.ICMPv6DstUnreachable),
- code: uint8(header.ICMPv6PortUnreachable),
- kind: stack.DestinationPortUnreachableTransportError,
- }
-
- cases := []struct {
- name string
- expectedCount int
- fragmentOffset *uint16
- typ header.ICMPv6Type
- code header.ICMPv6Code
- transErr transportError
- trunc int
- }{
- {
- name: "PacketTooBig",
- expectedCount: 1,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP6,
- typ: uint8(header.ICMPv6PacketTooBig),
- code: uint8(header.ICMPv6UnusedCode),
- info: mtu,
- kind: stack.PacketTooBigTransportError,
- },
- trunc: 0,
- },
- {
- name: "Truncated (missing offending packet's IPv6 header)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize,
- },
- {
- name: "Truncated PacketTooBig (partial offending packet's IPv6 header)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1,
- },
- {
- name: "Truncated (partial offending packet's data)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1,
- },
- {
- name: "Port unreachable",
- expectedCount: 1,
- fragmentOffset: nil,
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- transErr: portUnreachableTransErr,
- trunc: 0,
- },
- {
- name: "Truncated DstPortUnreachable (partial offending packet's IP header)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1,
- },
- {
- name: "DstPortUnreachable for Fragmented, zero offset",
- expectedCount: 1,
- fragmentOffset: newUint16(0),
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- transErr: portUnreachableTransErr,
- trunc: 0,
- },
- {
- name: "DstPortUnreachable for Non-zero fragment offset",
- expectedCount: 0,
- fragmentOffset: newUint16(100),
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- transErr: portUnreachableTransErr,
- trunc: 0,
- },
- {
- name: "Zero-length packet",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen,
- },
- }
- for _, c := range cases {
- t.Run(c.name, func(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
- if c.fragmentOffset != nil {
- dataOffset += header.IPv6FragmentHeaderSize
- }
- view := buffer.NewView(dataOffset + dataLen)
-
- // Create the outer IPv6 header.
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIPv6Addr,
- })
-
- // Create the ICMP header.
- icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
- icmp.SetType(c.typ)
- icmp.SetCode(c.code)
- icmp.SetIdent(0xdead)
- icmp.SetSequence(0xbeef)
-
- var extHdrs header.IPv6ExtHdrSerializer
- // Build the fragmentation header if needed.
- if c.fragmentOffset != nil {
- extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: *c.fragmentOffset,
- M: true,
- Identification: 0x12345678,
- })
- }
-
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- TransportProtocol: 10,
- HopLimit: 20,
- SrcAddr: localIPv6Addr,
- DstAddr: remoteIPv6Addr,
- ExtensionHeaders: extHdrs,
- })
-
- // Make payload be non-zero.
- for i := dataOffset; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to IPv6 endpoint, dispatcher will validate that
- // it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[dataOffset:]
- nic.testObject.transErr = c.transErr
-
- // Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: outerSrcAddr,
- Dst: localIPv6Addr,
- }))
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
- }
- addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- ep.DecRef()
- }
- pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
- ep.HandlePacket(pkt)
- if want := c.expectedCount; nic.testObject.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
- }
- })
- }
-}
-
-// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
-// after truncation, is large enough to hold a network header, it makes part of
-// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
-// becomes Data.
-func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
- v := view[:len(view)-trunc]
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
- return pkt
-}
-
-func TestWriteHeaderIncludedPacket(t *testing.T) {
- const (
- nicID = 1
- transportProto = 5
-
- dataLen = 4
- )
-
- dataBuf := [dataLen]byte{1, 2, 3, 4}
- data := dataBuf[:]
-
- ipv4Options := header.IPv4OptionsSerializer{
- &header.IPv4SerializableListEndOption{},
- &header.IPv4SerializableNOPOption{},
- &header.IPv4SerializableListEndOption{},
- &header.IPv4SerializableNOPOption{},
- }
-
- expectOptions := header.IPv4Options{
- byte(header.IPv4OptionListEndType),
- byte(header.IPv4OptionNOPType),
- byte(header.IPv4OptionListEndType),
- byte(header.IPv4OptionNOPType),
- }
-
- ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
- ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
-
- var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
- ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
- if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
- }
- if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
-
- tests := []struct {
- name string
- protoFactory stack.NetworkProtocolFactory
- protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.Address
- remoteAddr tcpip.Address
- pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
- checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
- expectedErr tcpip.Error
- }{
- {
- name: "IPv4",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv4MinimumSize + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv4MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv4 with IHL too small",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv4MinimumSize + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- ip.SetHeaderLength(header.IPv4MinimumSize - 1)
- return hdr.View().ToVectorisedView()
- },
- expectedErr: &tcpip.ErrMalformedHeader{},
- },
- {
- name: "IPv4 too small",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
- },
- expectedErr: &tcpip.ErrMalformedHeader{},
- },
- {
- name: "IPv4 minimum size",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- return buffer.View(ip).ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv4MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.IPFullLength(header.IPv4MinimumSize),
- checker.IPPayload(nil),
- )
- },
- },
- {
- name: "IPv4 with options",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
- totalLen := ipHdrLen + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv4(hdr.Prepend(ipHdrLen))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- Options: ipv4Options,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
- if len(netHdr.View()) != hdrLen {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(hdrLen),
- checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(expectOptions),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv4 with options and data across views",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- Options: ipv4Options,
- })
- vv := buffer.View(ip).ToVectorisedView()
- vv.AppendView(data)
- return vv
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
- if len(netHdr.View()) != hdrLen {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(hdrLen),
- checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(expectOptions),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv6",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv6MinimumSize + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- TransportProtocol: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv6Any {
- src = localIPv6Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv6MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
- }
-
- checker.IPv6(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv6 with extension header",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
- }
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- // NB: we're lying about transport protocol here to verify the raw
- // fragment header bytes.
- TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv6Any {
- src = localIPv6Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
- }
-
- checker.IPv6(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
- checker.IPPayload(ipv6PayloadWithExtHdr),
- )
- },
- },
- {
- name: "IPv6 minimum size",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- TransportProtocol: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- return buffer.View(ip).ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv6Any {
- src = localIPv6Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv6MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
- }
-
- checker.IPv6(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(header.IPv6MinimumSize),
- checker.IPPayload(nil),
- )
- },
- },
- {
- name: "IPv6 too small",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- TransportProtocol: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
- })
- return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
- },
- expectedErr: &tcpip.ErrMalformedHeader{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- subTests := []struct {
- name string
- srcAddr tcpip.Address
- }{
- {
- name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
- },
- {
- name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
- },
- }
-
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
-
- r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
- }
- defer r.Release()
-
- {
- err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr),
- }))
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff)
- }
- }
-
- if test.expectedErr != nil {
- return
- }
-
- pkt, ok := e.Read()
- if !ok {
- t.Fatal("expected a packet to be written")
- }
- test.checker(t, pkt.Pkt, subTest.srcAddr)
- })
- }
- })
- }
-}
-
-// Test that the included data in an ICMP error packet conforms to the
-// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3
-func TestICMPInclusionSize(t *testing.T) {
- const (
- replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize
- replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize
- targetSize4 = header.IPv4MinimumProcessableDatagramSize
- targetSize6 = header.IPv6MinimumMTU
- // A protocol number that will cause an error response.
- reservedProtocol = 254
- )
-
- // IPv4 function to create a IP packet and send it to the stack.
- // The packet should generate an error response. We can do that by using an
- // unknown transport protocol (254).
- rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
- totalLen := header.IPv4MinimumSize + len(payload)
- hdr := buffer.NewPrependable(header.IPv4MinimumSize)
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: reservedProtocol,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(buffer.View(payload))
- // Take a copy before InjectInbound takes ownership of vv
- // as vv may be changed during the call.
- v := vv.ToView()
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return v
- }
-
- // IPv6 function to create a packet and send it to the stack.
- // The packet should be errant in a way that causes the stack to send an
- // ICMP error response and have enough data to allow the testing of the
- // inclusion of the errant packet. Use `unknown next header' to generate
- // the error.
- rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize)
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload)),
- TransportProtocol: reservedProtocol,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
- })
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(buffer.View(payload))
- // Take a copy before InjectInbound takes ownership of vv
- // as vv may be changed during the call.
- v := vv.ToView()
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return v
- }
-
- v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
- // We already know the entire packet is the right size so we can use its
- // length to calculate the right payload size to check.
- expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
- checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()),
- checker.SrcAddr(localIPv4Addr),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4ProtoUnreachable),
- checker.ICMPv4Payload(payload[:expectedPayloadLength]),
- ),
- )
- }
-
- v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
- // We already know the entire packet is the right size so we can use its
- // length to calculate the right payload size to check.
- expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize
- checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()),
- checker.SrcAddr(localIPv6Addr),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6ParamProblem),
- checker.ICMPv6Code(header.ICMPv6UnknownHeader),
- checker.ICMPv6Payload(payload[:expectedPayloadLength]),
- ),
- )
- }
- tests := []struct {
- name string
- srcAddress tcpip.Address
- injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View
- checker func(*testing.T, *stack.PacketBuffer, buffer.View)
- payloadLength int // Not including IP header.
- linkMTU uint32 // Largest IP packet that the link can send as payload.
- replyLength int // Total size of IP/ICMP packet expected back.
- }{
- {
- name: "IPv4 exact match",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4 - replyHeaderLength4,
- linkMTU: targetSize4,
- replyLength: targetSize4,
- },
- {
- name: "IPv4 larger MTU",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4,
- linkMTU: targetSize4 + 1000,
- replyLength: targetSize4,
- },
- {
- name: "IPv4 smaller MTU",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4,
- linkMTU: targetSize4 - 50,
- replyLength: targetSize4 - 50,
- },
- {
- name: "IPv4 payload exceeds",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4 + 10,
- linkMTU: targetSize4,
- replyLength: targetSize4,
- },
- {
- name: "IPv4 1 byte less",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4 - replyHeaderLength4 - 1,
- linkMTU: targetSize4,
- replyLength: targetSize4 - 1,
- },
- {
- name: "IPv4 No payload",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: 0,
- linkMTU: targetSize4,
- replyLength: replyHeaderLength4,
- },
- {
- name: "IPv6 exact match",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6 - replyHeaderLength6,
- linkMTU: targetSize6,
- replyLength: targetSize6,
- },
- {
- name: "IPv6 larger MTU",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6,
- linkMTU: targetSize6 + 400,
- replyLength: targetSize6,
- },
- // NB. No "smaller MTU" test here as less than 1280 is not permitted
- // in IPv6.
- {
- name: "IPv6 payload exceeds",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6,
- linkMTU: targetSize6,
- replyLength: targetSize6,
- },
- {
- name: "IPv6 1 byte less",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6 - replyHeaderLength6 - 1,
- linkMTU: targetSize6,
- replyLength: targetSize6 - 1,
- },
- {
- name: "IPv6 no payload",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: 0,
- linkMTU: targetSize6,
- replyLength: replyHeaderLength6,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU)
- // Allocate and initialize the payload view.
- payload := buffer.NewView(test.payloadLength)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
- // Default routes for IPv4&6 so ICMP can find a route to the remote
- // node when attempting to send the ICMP error Reply.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
- v := test.injector(e, test.srcAddress, payload)
- pkt, ok := e.Read()
- if !ok {
- t.Fatal("expected a packet to be written")
- }
- if got, want := pkt.Pkt.Size(), test.replyLength; got != want {
- t.Fatalf("got %d bytes of icmp error packet, want %d", got, want)
- }
- test.checker(t, pkt.Pkt, v)
- })
- }
-}
-
-func TestJoinLeaveAllRoutersGroup(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- protoFactory stack.NetworkProtocolFactory
- allRoutersAddr tcpip.Address
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- protoFactory: ipv4.NewProtocol,
- allRoutersAddr: header.IPv4AllRoutersGroup,
- },
- {
- name: "IPv6 Interface Local",
- netProto: ipv6.ProtocolNumber,
- protoFactory: ipv6.NewProtocol,
- allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress,
- },
- {
- name: "IPv6 Link Local",
- netProto: ipv6.ProtocolNumber,
- protoFactory: ipv6.NewProtocol,
- allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress,
- },
- {
- name: "IPv6 Site Local",
- netProto: ipv6.ProtocolNumber,
- protoFactory: ipv6.NewProtocol,
- allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, nicDisabled := range [...]bool{true, false} {
- t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- opts := stack.NICOptions{Disabled: nicDisabled}
- if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
- }
-
- if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
- t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
- } else if got {
- t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
- }
-
- if err := s.SetForwarding(test.netProto, true); err != nil {
- t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err)
- }
- if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
- t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
- } else if !got {
- t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
- }
-
- if err := s.SetForwarding(test.netProto, false); err != nil {
- t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err)
- }
- if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
- t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
- } else if got {
- t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
deleted file mode 100644
index c90974693..000000000
--- a/pkg/tcpip/network/ipv4/BUILD
+++ /dev/null
@@ -1,68 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ipv4",
- srcs = [
- "icmp.go",
- "igmp.go",
- "ipv4.go",
- "stats.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/hash",
- "//pkg/tcpip/network/internal/fragmentation",
- "//pkg/tcpip/network/internal/ip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ipv4_test",
- size = "small",
- srcs = [
- "igmp_test.go",
- "ipv4_test.go",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/internal/testutil",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/raw",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "stats_test",
- size = "small",
- srcs = ["stats_test.go"],
- library = ":ipv4",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- ],
-)
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
deleted file mode 100644
index 4bd6f462e..000000000
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ /dev/null
@@ -1,387 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv4_test
-
-import (
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- nicID = 1
- defaultTTL = 1
- defaultPrefixLength = 24
-)
-
-var (
- stackAddr = testutil.MustParse4("10.0.0.1")
- remoteAddr = testutil.MustParse4("10.0.0.2")
- multicastAddr = testutil.MustParse4("224.0.0.3")
-)
-
-// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
-// sent to the provided address with the passed fields set. Raises a t.Error if
-// any field does not match.
-func validateIgmpPacket(t *testing.T, p channel.PacketInfo, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) {
- t.Helper()
-
- payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
- checker.IPv4(t, payload,
- checker.SrcAddr(srcAddr),
- checker.DstAddr(dstAddr),
- // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
- checker.TTL(1),
- checker.IPv4RouterAlert(),
- checker.IGMP(
- checker.IGMPType(igmpType),
- checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
- checker.IGMPGroupAddress(groupAddress),
- ),
- )
-}
-
-func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
- t.Helper()
-
- // Create an endpoint of queue size 1, since no more than 1 packets are ever
- // queued in the tests in this file.
- e := channel.New(1, 1280, linkAddr)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
- IGMP: ipv4.IGMPOptions{
- Enabled: igmpEnabled,
- },
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- return e, s, clock
-}
-
-func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) {
- var options header.IPv4OptionsSerializer
- if hasRouterAlertOption {
- options = header.IPv4OptionsSerializer{
- &header.IPv4SerializableRouterAlertOption{},
- }
- }
- buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize)
-
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: ttl,
- Protocol: uint8(header.IGMPProtocolNumber),
- SrcAddr: srcAddr,
- DstAddr: dstAddr,
- Options: options,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- igmp := header.IGMP(ip.Payload())
- igmp.SetType(igmpType)
- igmp.SetMaxRespTime(maxRespTime)
- igmp.SetGroupAddress(groupAddress)
- igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
-
- e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1
-// router is detected. V1 present status is expected to be reset when the NIC
-// cycles.
-func TestIGMPV1Present(t *testing.T) {
- e, s, clock := createStack(t, true)
- addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
- }
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // This NIC will send an IGMPv2 report immediately, before this test can get
- // the IGMPv1 General Membership Query in.
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Inject an IGMPv1 General Membership Query which is identical to a standard
- // membership query except the Max Response Time is set to 0, which will tell
- // the stack that this is a router using IGMPv1. Send it to the all systems
- // group which is the only group this host belongs to.
- createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, header.IPv4AllSystems, true /* hasRouterAlertOption */)
- if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
- t.Fatalf("got Membership Queries received = %d, want = 1", got)
- }
-
- // Before advancing the clock, verify that this host has not sent a
- // V1MembershipReport yet.
- if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
- t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
- }
-
- // Verify the solicited Membership Report is sent. Now that this NIC has seen
- // an IGMPv1 query, it should send an IGMPv1 Membership Report.
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
- t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
-
- // Cycling the interface should reset the V1 present flag.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
-}
-
-func TestSendQueuedIGMPReports(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- // Joining a group without an assigned address should queue IGMP packets; none
- // should be sent without an assigned address.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
- }
- reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
- if got := reportStat.Value(); got != 0 {
- t.Errorf("got reportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("got unexpected packet = %#v", p)
- }
-
- // The initial set of IGMP reports that were queued should be sent once an
- // address is assigned.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
- }
- if got := reportStat.Value(); got != 1 {
- t.Errorf("got reportStat.Value() = %d, want = 1", got)
- }
- if p, ok := e.Read(); !ok {
- t.Error("expected to send an IGMP membership report")
- } else {
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
- if t.Failed() {
- t.FailNow()
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- if got := reportStat.Value(); got != 2 {
- t.Errorf("got reportStat.Value() = %d, want = 2", got)
- }
- if p, ok := e.Read(); !ok {
- t.Error("expected to send an IGMP membership report")
- } else {
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Should have no more packets to send after the initial set of unsolicited
- // reports.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("got unexpected packet = %#v", p)
- }
-}
-
-func TestIGMPPacketValidation(t *testing.T) {
- tests := []struct {
- name string
- messageType header.IGMPType
- stackAddresses []tcpip.AddressWithPrefix
- srcAddr tcpip.Address
- includeRouterAlertOption bool
- ttl uint8
- expectValidIGMP bool
- getMessageTypeStatValue func(tcpip.Stats) uint64
- }{
- {
- name: "valid",
- messageType: header.IGMPLeaveGroup,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: remoteAddr,
- ttl: 1,
- expectValidIGMP: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
- },
- {
- name: "bad ttl",
- messageType: header.IGMPv1MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: remoteAddr,
- ttl: 2,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
- },
- {
- name: "missing router alert ip option",
- messageType: header.IGMPv2MembershipReport,
- includeRouterAlertOption: false,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: remoteAddr,
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
- },
- {
- name: "igmp leave group and src ip does not belong to nic subnet",
- messageType: header.IGMPLeaveGroup,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
- },
- {
- name: "igmp query and src ip does not belong to nic subnet",
- messageType: header.IGMPMembershipQuery,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() },
- },
- {
- name: "igmp report v1 and src ip does not belong to nic subnet",
- messageType: header.IGMPv1MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
- },
- {
- name: "igmp report v2 and src ip does not belong to nic subnet",
- messageType: header.IGMPv2MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
- },
- {
- name: "src ip belongs to the subnet of the nic's second address",
- messageType: header.IGMPv2MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{
- {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24},
- {Address: stackAddr, PrefixLen: 24},
- },
- srcAddr: remoteAddr,
- ttl: 1,
- expectValidIGMP: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, true)
- for _, address := range test.stackAddresses {
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
- }
- }
- stats := s.Stats()
- // Verify that every relevant stats is zero'd before we send a packet.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 0 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got)
- }
- if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 {
- t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got)
- }
- if got := stats.IP.PacketsDelivered.Value(); got != 0 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got)
- }
- createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption)
- // We always expect the packet to pass IP validation.
- if got := stats.IP.PacketsDelivered.Value(); got != 1 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got)
- }
- // Even when the IGMP-specific validation checks fail, we expect the
- // corresponding IGMP counter to be incremented.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 1 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got)
- }
- var expectedInvalidCount uint64
- if !test.expectValidIGMP {
- expectedInvalidCount = 1
- }
- if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount {
- t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
new file mode 100644
index 000000000..19b672251
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
@@ -0,0 +1,113 @@
+// automatically generated by stateify.
+
+package ipv4
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (i *icmpv4DestinationUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4DestinationUnreachableSockError"
+}
+
+func (i *icmpv4DestinationUnreachableSockError) StateFields() []string {
+ return []string{}
+}
+
+func (i *icmpv4DestinationUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+}
+
+func (i *icmpv4DestinationUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnreachableSockError"
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv4DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationHostUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationHostUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4DestinationPortUnreachableSockError"
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv4DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (e *icmpv4FragmentationNeededSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4FragmentationNeededSockError"
+}
+
+func (e *icmpv4FragmentationNeededSockError) StateFields() []string {
+ return []string{
+ "icmpv4DestinationUnreachableSockError",
+ "mtu",
+ }
+}
+
+func (e *icmpv4FragmentationNeededSockError) beforeSave() {}
+
+// +checklocksignore
+func (e *icmpv4FragmentationNeededSockError) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.icmpv4DestinationUnreachableSockError)
+ stateSinkObject.Save(1, &e.mtu)
+}
+
+func (e *icmpv4FragmentationNeededSockError) afterLoad() {}
+
+// +checklocksignore
+func (e *icmpv4FragmentationNeededSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.icmpv4DestinationUnreachableSockError)
+ stateSourceObject.Load(1, &e.mtu)
+}
+
+func init() {
+ state.Register((*icmpv4DestinationUnreachableSockError)(nil))
+ state.Register((*icmpv4DestinationHostUnreachableSockError)(nil))
+ state.Register((*icmpv4DestinationPortUnreachableSockError)(nil))
+ state.Register((*icmpv4FragmentationNeededSockError)(nil))
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
deleted file mode 100644
index 7a7cad04a..000000000
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ /dev/null
@@ -1,3228 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv4_test
-
-import (
- "bytes"
- "context"
- "encoding/hex"
- "fmt"
- "io/ioutil"
- "math"
- "net"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- extraHeaderReserve = 50
- defaultMTU = 65536
-)
-
-func TestExcludeBroadcast(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
- if testing.Verbose() {
- ep = sniffer.New(ep)
- }
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- }})
-
- randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
-
- var wq waiter.Queue
- t.Run("WithoutPrimaryAddress", func(t *testing.T) {
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- // Cannot connect using a broadcast address as the source.
- {
- err := ep.Connect(randomAddr)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{})
- }
- }
-
- // However, we can bind to a broadcast address to listen.
- if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil {
- t.Errorf("Bind failed: %v", err)
- }
- })
-
- t.Run("WithPrimaryAddress", func(t *testing.T) {
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- // Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
- if err := ep.Connect(randomAddr); err != nil {
- t.Errorf("Connect failed: %v", err)
- }
- })
-}
-
-func TestForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
- randomSequence = 123
- randomIdent = 42
- randomTimeOffset = 0x10203040
- )
-
- ipv4Addr1 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
- PrefixLen: 8,
- }
- ipv4Addr2 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
- PrefixLen: 8,
- }
- remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
- remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
- unreachableIPv4Addr := tcpip.Address(net.ParseIP("12.0.0.2").To4())
- multicastIPv4Addr := tcpip.Address(net.ParseIP("225.0.0.0").To4())
- linkLocalIPv4Addr := tcpip.Address(net.ParseIP("169.254.0.0").To4())
-
- tests := []struct {
- name string
- TTL uint8
- sourceAddr tcpip.Address
- destAddr tcpip.Address
- expectErrorICMP bool
- expectPacketForwarded bool
- options header.IPv4Options
- forwardedOptions header.IPv4Options
- icmpType header.ICMPv4Type
- icmpCode header.ICMPv4Code
- expectPacketUnrouteableError bool
- expectLinkLocalSourceError bool
- expectLinkLocalDestError bool
- }{
- {
- name: "TTL of zero",
- TTL: 0,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectErrorICMP: true,
- icmpType: header.ICMPv4TimeExceeded,
- icmpCode: header.ICMPv4TTLExceeded,
- },
- {
- name: "TTL of one",
- TTL: 1,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- },
- {
- name: "TTL of two",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- },
- {
- name: "Max TTL",
- TTL: math.MaxUint8,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- },
- {
- name: "four EOL options",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- options: header.IPv4Options{0, 0, 0, 0},
- forwardedOptions: header.IPv4Options{0, 0, 0, 0},
- },
- {
- name: "TS type 1 full",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- options: header.IPv4Options{
- 68, 12, 13, 0xF1,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv4ParamProblem,
- icmpCode: header.ICMPv4UnusedCode,
- },
- {
- name: "TS type 0",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- options: header.IPv4Options{
- 68, 24, 21, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0, 0, 0, 0,
- },
- forwardedOptions: header.IPv4Options{
- 68, 24, 25, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- expectPacketForwarded: true,
- },
- {
- name: "end of options list",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- options: header.IPv4Options{
- 68, 12, 13, 0x11,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, 10, 3, 99, // EOL followed by junk
- 1, 2, 3, 4,
- },
- forwardedOptions: header.IPv4Options{
- 68, 12, 13, 0x21,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, // End of Options hides following bytes.
- 0, 0, 0, // 7 bytes unknown option removed.
- 0, 0, 0, 0,
- },
- expectPacketForwarded: true,
- },
- {
- name: "Network unreachable",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: unreachableIPv4Addr,
- expectErrorICMP: true,
- icmpType: header.ICMPv4DstUnreachable,
- icmpCode: header.ICMPv4NetUnreachable,
- expectPacketUnrouteableError: true,
- },
- {
- name: "Multicast destination",
- TTL: 2,
- destAddr: multicastIPv4Addr,
- expectPacketUnrouteableError: true,
- },
- {
- name: "Link local destination",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: linkLocalIPv4Addr,
- expectLinkLocalDestError: true,
- },
- {
- name: "Link local source",
- TTL: 2,
- sourceAddr: linkLocalIPv4Addr,
- destAddr: remoteIPv4Addr2,
- expectLinkLocalSourceError: true,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- Clock: clock,
- })
-
- // Advance the clock by some unimportant amount to make
- // it give a more recognisable signature than 00,00,00,00.
- clock.Advance(time.Millisecond * randomTimeOffset)
-
- // We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
- }
-
- e2 := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: ipv4Addr1.Subnet(),
- NIC: nicID1,
- },
- {
- Destination: ipv4Addr2.Subnet(),
- NIC: nicID2,
- },
- })
-
- if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
- }
-
- ipHeaderLength := header.IPv4MinimumSize + len(test.options)
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(totalLen))
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv4Echo)
- icmp.SetCode(header.ICMPv4UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(^header.Checksum(icmp, 0))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: test.TTL,
- SrcAddr: test.sourceAddr,
- DstAddr: test.destAddr,
- })
- if len(test.options) != 0 {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- // Copy options manually. We do not use Encode for options so we can
- // verify malformed options with handcrafted payloads.
- if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
- t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
- }
- }
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
-
- reply, ok := e1.Read()
- if test.expectErrorICMP {
- if !ok {
- t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
- }
-
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv4Addr1.Address),
- checker.DstAddr(test.sourceAddr),
- checker.TTL(ipv4.DefaultTTL),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(test.icmpType),
- checker.ICMPv4Code(test.icmpCode),
- checker.ICMPv4Payload([]byte(hdr.View())),
- ),
- )
-
- if n := e2.Drain(); n != 0 {
- t.Fatalf("got e2.Drain() = %d, want = 0", n)
- }
- } else if ok {
- t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
- }
-
- reply, ok = e2.Read()
- if test.expectPacketForwarded {
- if !ok {
- t.Fatal("expected ICMP Echo packet through outgoing NIC")
- }
-
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(test.sourceAddr),
- checker.DstAddr(test.destAddr),
- checker.TTL(test.TTL-1),
- checker.IPv4Options(test.forwardedOptions),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Code(header.ICMPv4UnusedCode),
- checker.ICMPv4Payload(nil),
- ),
- )
-
- if n := e1.Drain(); n != 0 {
- t.Fatalf("got e1.Drain() = %d, want = 0", n)
- }
- } else if ok {
- t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
- }
-
- boolToInt := func(val bool) uint64 {
- if val {
- return 1
- }
- return 0
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want {
- t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
-// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
-// checks the response.
-func TestIPv4Sanity(t *testing.T) {
- const (
- ttl = 255
- nicID = 1
- randomSequence = 123
- randomIdent = 42
- // In some cases Linux sets the error pointer to the start of the option
- // (offset 0) instead of the actual wrong value, which is the length byte
- // (offset 1). For compatibility we must do the same. Use this constant
- // to indicate where this happens.
- pointerOffsetForInvalidLength = 0
- randomTimeOffset = 0x10203040
- )
- var (
- ipv4Addr = tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- }
- remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
- )
-
- tests := []struct {
- name string
- headerLength uint8 // value of 0 means "use correct size"
- badHeaderChecksum bool
- maxTotalLength uint16
- transportProtocol uint8
- TTL uint8
- options header.IPv4Options
- replyOptions header.IPv4Options // reply should look like this
- shouldFail bool
- expectErrorICMP bool
- ICMPType header.ICMPv4Type
- ICMPCode header.ICMPv4Code
- paramProblemPointer uint8
- }{
- {
- name: "valid no options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- },
- {
- name: "bad header checksum",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- badHeaderChecksum: true,
- shouldFail: true,
- },
- // The TTL tests check that we are not rejecting an incoming packet
- // with a zero or one TTL, which has been a point of confusion in the
- // past as RFC 791 says: "If this field contains the value zero, then the
- // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
- // for the case of the destination host, stating as follows.
- //
- // A host MUST NOT send a datagram with a Time-to-Live (TTL)
- // value of zero.
- //
- // A host MUST NOT discard a datagram just because it was
- // received with TTL less than 2.
- {
- name: "zero TTL",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: 0,
- },
- {
- name: "one TTL",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: 1,
- },
- {
- name: "End options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{0, 0, 0, 0},
- replyOptions: header.IPv4Options{0, 0, 0, 0},
- },
- {
- name: "NOP options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 1, 1},
- replyOptions: header.IPv4Options{1, 1, 1, 1},
- },
- {
- name: "NOP and End options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 0, 0},
- replyOptions: header.IPv4Options{1, 1, 0, 0},
- },
- {
- name: "bad header length",
- headerLength: header.IPv4MinimumSize - 1,
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad total length (0)",
- maxTotalLength: 0,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad total length (ip - 1)",
- maxTotalLength: uint16(header.IPv4MinimumSize - 1),
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad total length (ip + icmp - 1)",
- maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad protocol",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: 99,
- TTL: ttl,
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4DstUnreachable,
- ICMPCode: header.ICMPv4ProtoUnreachable,
- },
- {
- name: "timestamp option overflow",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 12, 13, 0x11,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- replyOptions: header.IPv4Options{
- 68, 12, 13, 0x21,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- },
- {
- name: "timestamp option overflow full",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 12, 13, 0xF1,
- // ^ Counter full (15/0xF)
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 3,
- replyOptions: header.IPv4Options{},
- },
- {
- name: "unknown option",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{10, 4, 9, 0},
- // ^^
- // The unknown option should be stripped out of the reply.
- replyOptions: header.IPv4Options{},
- },
- {
- name: "bad option - no length",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 1, 1, 1, 68,
- // ^-start of timestamp.. but no length..
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 3,
- },
- {
- name: "bad option - length 0",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 0, 9, 0,
- // ^
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- name: "bad option - length 1",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 1, 9, 0,
- // ^
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- name: "bad option - length big",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 9, 9, 0,
- // ^
- // There are only 8 bytes allocated to options so 9 bytes of timestamp
- // space is not possible. (Second byte)
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- // This tests for some linux compatible behaviour.
- // The ICMP pointer returned is 22 for Linux but the
- // error is actually in spot 21.
- name: "bad option - length bad",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- // Timestamps are in multiples of 4 or 8 but never 7.
- // The option space should be padded out.
- options: header.IPv4Options{
- 68, 7, 5, 0,
- // ^ ^ Linux points here which is wrong.
- // | Not a multiple of 4
- 1, 2, 3, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
- },
- {
- name: "multiple type 0 with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 24, 21, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0, 0, 0, 0,
- },
- replyOptions: header.IPv4Options{
- 68, 24, 25, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- },
- {
- // The timestamp area is full so add to the overflow count.
- name: "multiple type 1 timestamps",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 20, 21, 0x11,
- // ^
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- },
- // Overflow count is the top nibble of the 4th byte.
- replyOptions: header.IPv4Options{
- 68, 20, 21, 0x21,
- // ^
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- },
- },
- {
- name: "multiple type 1 timestamps with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 28, 21, 0x01,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- replyOptions: header.IPv4Options{
- 68, 28, 29, 0x01,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- 192, 168, 1, 58, // New IP Address.
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- },
- {
- // Timestamp pointer uses one based counting so 0 is invalid.
- name: "timestamp pointer invalid",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 8, 0, 0x00,
- // ^ 0 instead of 5 or more.
- 0, 0, 0, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
- },
- {
- // Timestamp pointer cannot be less than 5. It must point past the header
- // which is 4 bytes. (1 based counting)
- name: "timestamp pointer too small by 1",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 8, header.IPv4OptionTimestampHdrLength, 0x00,
- // ^ header is 4 bytes, so 4 should fail.
- 0, 0, 0, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
- },
- {
- name: "valid timestamp pointer",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00,
- // ^ header is 4 bytes, so 5 should succeed.
- 0, 0, 0, 0,
- },
- replyOptions: header.IPv4Options{
- 68, 8, 9, 0x00,
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- },
- {
- // Needs 8 bytes for a type 1 timestamp but there are only 4 free.
- name: "bad timer element alignment",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 20, 17, 0x01,
- // ^^ ^^ 20 byte area, next free spot at 17.
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
- },
- // End of option list with illegal option after it, which should be ignored.
- {
- name: "end of options list",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 12, 13, 0x11,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, 10, 3, 99, // EOL followed by junk
- },
- replyOptions: header.IPv4Options{
- 68, 12, 13, 0x21,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, // End of Options hides following bytes.
- 0, 0, 0, // 3 bytes unknown option removed.
- },
- },
- {
- // Timestamp with a size much too small.
- name: "timestamp truncated",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 1, 0, 0,
- // ^ Smallest possible is 8. Linux points at the 68.
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- name: "single record route with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, 4, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 192, 168, 1, 58, // New IP Address.
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- name: "multiple record route with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 23, 20, // 3 byte header
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0, 0, 0, 0,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 23, 24,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 192, 168, 1, 58, // New IP Address.
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- name: "single record route with no room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- // Unlike timestamp, this should just succeed.
- name: "multiple record route with no room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 23, 24, // 3 byte header
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 23, 24,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- // Pointer uses one based counting so 0 is invalid.
- name: "record route pointer zero",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 8, 0, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
- },
- {
- // Pointer must be 4 or more as it must point past the 3 byte header
- // using 1 based counting. 3 should fail.
- name: "record route pointer too small by 1",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
- },
- {
- // Pointer must be 4 or more as it must point past the 3 byte header
- // using 1 based counting. Check 4 passes. (Duplicates "single
- // record route with room")
- name: "valid record route pointer",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 192, 168, 1, 58, // New IP Address.
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- // Confirm Linux bug for bug compatibility.
- // Linux returns slot 22 but the error is in slot 21.
- name: "multiple record route with not enough room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 8, 8, // 3 byte header
- // ^ ^ Linux points here. We must too.
- // | Not enough room. 1 byte free, need 4.
- 1, 2, 3, 4,
- 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
- },
- {
- name: "duplicate record route",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 0, 0, // pad
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 7,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- Clock: clock,
- })
- // Advance the clock by some unimportant amount to make
- // it give a more recognisable signature than 00,00,00,00.
- clock.Advance(time.Millisecond * randomTimeOffset)
-
- // We expect at most a single packet in response to our ICMP Echo Request.
- e := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
- }
-
- // Default routes for IPv4 so ICMP can find a route to the remote
- // node when attempting to send the ICMP Echo Reply.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- if len(test.options)%4 != 0 {
- t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options))
- }
- ipHeaderLength := header.IPv4MinimumSize + len(test.options)
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(totalLen))
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
-
- // Specify ident/seq to make sure we get the same in the response.
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv4Echo)
- icmp.SetCode(header.ICMPv4UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(^header.Checksum(icmp, 0))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- if test.maxTotalLength < totalLen {
- totalLen = test.maxTotalLength
- }
- ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
- Protocol: test.transportProtocol,
- TTL: test.TTL,
- SrcAddr: remoteIPv4Addr,
- DstAddr: ipv4Addr.Address,
- })
- if test.headerLength != 0 {
- ip.SetHeaderLength(test.headerLength)
- } else {
- // Set the calculated header length, since we may manually add options.
- ip.SetHeaderLength(uint8(ipHeaderLength))
- }
- if len(test.options) != 0 {
- // Copy options manually. We do not use Encode for options so we can
- // verify malformed options with handcrafted payloads.
- if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
- t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
- }
- }
- ip.SetChecksum(0)
- ipHeaderChecksum := ip.CalculateChecksum()
- if test.badHeaderChecksum {
- ipHeaderChecksum += 42
- }
- ip.SetChecksum(^ipHeaderChecksum)
- requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
- reply, ok := e.Read()
- if !ok {
- if test.shouldFail {
- if test.expectErrorICMP {
- t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode)
- }
- return // Expected silent failure.
- }
- t.Fatal("expected ICMP echo reply missing")
- }
-
- // We didn't expect a packet. Register our surprise but carry on to
- // provide more information about what we got.
- if test.shouldFail && !test.expectErrorICMP {
- t.Error("unexpected packet response")
- }
-
- // Check the route that brought the packet to us.
- if reply.Route.LocalAddress != ipv4Addr.Address {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
- }
- if reply.Route.RemoteAddress != remoteIPv4Addr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
- }
-
- // Make sure it's all in one buffer for checker.
- replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader()))
-
- // At this stage we only know it's probably an IP+ICMP header so verify
- // that much.
- checker.IPv4(t, replyIPHeader,
- checker.SrcAddr(ipv4Addr.Address),
- checker.DstAddr(remoteIPv4Addr),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- ),
- )
-
- // Don't proceed any further if the checker found problems.
- if t.Failed() {
- t.FailNow()
- }
-
- // OK it's ICMP. We can safely look at the type now.
- replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
- switch replyICMPHeader.Type() {
- case header.ICMPv4ParamProblem:
- if !test.shouldFail {
- t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer())
- }
- if !test.expectErrorICMP {
- t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer())
- }
- checker.IPv4(t, replyIPHeader,
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.ICMPv4(
- checker.ICMPv4Type(test.ICMPType),
- checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Pointer(test.paramProblemPointer),
- checker.ICMPv4Payload([]byte(hdr.View())),
- ),
- )
- return
- case header.ICMPv4DstUnreachable:
- if !test.shouldFail {
- t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply",
- header.ICMPv4DstUnreachable, replyICMPHeader.Code())
- }
- if !test.expectErrorICMP {
- t.Fatalf("got ICMP error packet type %d, code %d, wanted no response",
- header.ICMPv4DstUnreachable, replyICMPHeader.Code())
- }
- checker.IPv4(t, replyIPHeader,
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.ICMPv4(
- checker.ICMPv4Type(test.ICMPType),
- checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Payload([]byte(hdr.View())),
- ),
- )
- return
- case header.ICMPv4EchoReply:
- if test.shouldFail {
- if !test.expectErrorICMP {
- t.Error("got Echo Reply packet, want no response")
- } else {
- t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode)
- }
- }
- // If the IP options change size then the packet will change size, so
- // some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - len(test.options)
-
- checker.IPv4(t, replyIPHeader,
- checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
- checker.IPv4Options(test.replyOptions),
- checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Code(header.ICMPv4UnusedCode),
- checker.ICMPv4Seq(randomSequence),
- checker.ICMPv4Ident(randomIdent),
- ),
- )
- default:
- t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d",
- replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem)
- }
- })
- }
-}
-
-// comparePayloads compared the contents of all the packets against the contents
-// of the source packet.
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
- // Make a complete array of the sourcePacket packet.
- source := header.IPv4(packets[0].NetworkHeader().View())
- vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
- source = append(source, vv.ToView()...)
-
- // Make a copy of the IP header, which will be modified in some fields to make
- // an expected header.
- sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
- sourceCopy.SetChecksum(0)
- sourceCopy.SetFlagsFragmentOffset(0, 0)
- sourceCopy.SetTotalLength(0)
- // Build up an array of the bytes sent.
- var reassembledPayload buffer.VectorisedView
- for i, packet := range packets {
- // Confirm that the packet is valid.
- allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
- fragmentIPHeader := header.IPv4(allBytes.ToView())
- if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) {
- return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader))
- }
- if got := len(fragmentIPHeader); got > int(mtu) {
- return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu)
- }
- if got := fragmentIPHeader.TransportProtocol(); got != proto {
- return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
- }
- if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve {
- return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
- }
- if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
- return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
- }
- if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
- return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
- }
- if wantFragments[i].more {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset)
- } else {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset)
- }
- reassembledPayload.AppendView(packet.TransportHeader().View())
- reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView())
- // Clear out the checksum and length from the ip because we can't compare
- // it.
- sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
- sourceCopy.SetChecksum(0)
- sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
- if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
- return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- expected := buffer.View(source[source.HeaderLength():])
- if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" {
- return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
- }
-
- return nil
-}
-
-type fragmentInfo struct {
- offset uint16
- more bool
- payloadSize uint16
-}
-
-var fragmentationTests = []struct {
- description string
- mtu uint32
- transportHeaderLength int
- payloadSize int
- wantFragments []fragmentInfo
-}{
- {
- description: "No fragmentation",
- mtu: 1280,
- transportHeaderLength: 0,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1000, more: false},
- },
- },
- {
- description: "Fragmented",
- mtu: 1280,
- transportHeaderLength: 0,
- payloadSize: 2000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1256, more: true},
- {offset: 1256, payloadSize: 744, more: false},
- },
- },
- {
- description: "Fragmented with the minimum mtu",
- mtu: header.IPv4MinimumMTU,
- transportHeaderLength: 0,
- payloadSize: 100,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 48, more: true},
- {offset: 48, payloadSize: 48, more: true},
- {offset: 96, payloadSize: 4, more: false},
- },
- },
- {
- description: "Fragmented with mtu not a multiple of 8",
- mtu: header.IPv4MinimumMTU + 1,
- transportHeaderLength: 0,
- payloadSize: 100,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 48, more: true},
- {offset: 48, payloadSize: 48, more: true},
- {offset: 96, payloadSize: 4, more: false},
- },
- },
- {
- description: "No fragmentation with big header",
- mtu: 2000,
- transportHeaderLength: 100,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1100, more: false},
- },
- },
- {
- description: "Fragmented with big header",
- mtu: 1280,
- transportHeaderLength: 100,
- payloadSize: 1200,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1256, more: true},
- {offset: 1256, payloadSize: 44, more: false},
- },
- },
- {
- description: "Fragmented with MTU smaller than header",
- mtu: 300,
- transportHeaderLength: 1000,
- payloadSize: 500,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 280, more: true},
- {offset: 280, payloadSize: 280, more: true},
- {offset: 560, payloadSize: 280, more: true},
- {offset: 840, payloadSize: 280, more: true},
- {offset: 1120, payloadSize: 280, more: true},
- {offset: 1400, payloadSize: 100, more: false},
- },
- },
-}
-
-func TestFragmentationWritePacket(t *testing.T) {
- const ttl = 42
-
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- source := pkt.Clone()
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if err != nil {
- t.Fatalf("r.WritePacket(...): %s", err)
- }
- if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
- })
- }
-}
-
-func TestFragmentationWritePackets(t *testing.T) {
- const ttl = 42
- writePacketsTests := []struct {
- description string
- insertBefore int
- insertAfter int
- }{
- {
- description: "Single packet",
- insertBefore: 0,
- insertAfter: 0,
- },
- {
- description: "With packet before",
- insertBefore: 1,
- insertAfter: 0,
- },
- {
- description: "With packet after",
- insertBefore: 0,
- insertAfter: 1,
- },
- {
- description: "With packet before and after",
- insertBefore: 1,
- insertAfter: 1,
- },
- }
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
-
- for _, test := range writePacketsTests {
- t.Run(test.description, func(t *testing.T) {
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- var pkts stack.PacketBufferList
- for i := 0; i < test.insertBefore; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- pkts.PushBack(pkt.Clone())
- for i := 0; i < test.insertAfter; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
-
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
-
- wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
- n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- })
- if err != nil {
- t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err)
- }
- if n != wantTotalPackets {
- t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets)
- }
- if got := len(ep.WrittenPackets); got != wantTotalPackets {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
- }
- if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
-
- if wantTotalPackets == 0 {
- return
- }
-
- fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
- if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
- })
- }
- })
- }
-}
-
-// TestFragmentationErrors checks that errors are returned from WritePacket
-// correctly.
-func TestFragmentationErrors(t *testing.T) {
- const ttl = 42
-
- tests := []struct {
- description string
- mtu uint32
- transportHeaderLength int
- payloadSize int
- allowPackets int
- outgoingErrors int
- mockError tcpip.Error
- wantError tcpip.Error
- }{
- {
- description: "No frag",
- mtu: 2000,
- payloadSize: 1000,
- transportHeaderLength: 0,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on first frag",
- mtu: 500,
- payloadSize: 1000,
- transportHeaderLength: 0,
- allowPackets: 0,
- outgoingErrors: 3,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on second frag",
- mtu: 500,
- payloadSize: 1000,
- transportHeaderLength: 0,
- allowPackets: 1,
- outgoingErrors: 2,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on first frag MTU smaller than header",
- mtu: 500,
- transportHeaderLength: 1000,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 4,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error when MTU is smaller than IPv4 minimum MTU",
- mtu: header.IPv4MinimumMTU - 1,
- transportHeaderLength: 0,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: nil,
- wantError: &tcpip.ErrInvalidEndpointState{},
- },
- }
-
- for _, ft := range tests {
- t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
- r := buildRoute(t, ep)
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if diff := cmp.Diff(ft.wantError, err); diff != "" {
- t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
- t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
- }
- if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
- }
- })
- }
-}
-
-func TestInvalidFragments(t *testing.T) {
- const (
- nicID = 1
- linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
- tos = 0
- ident = 1
- ttl = 48
- protocol = 6
- )
-
- payloadGen := func(payloadLen int) []byte {
- payload := make([]byte, payloadLen)
- for i := 0; i < len(payload); i++ {
- payload[i] = 0x30
- }
- return payload
- }
-
- type fragmentData struct {
- ipv4fields header.IPv4Fields
- // 0 means insert the correct IHL. Non 0 means override the correct IHL.
- overrideIHL int // For 0 use 1 as it is an int and will be divided by 4.
- payload []byte
- autoChecksum bool // If true, the Checksum field will be overwritten.
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- wantMalformedIPPackets uint64
- wantMalformedFragments uint64
- }{
- {
- name: "IHL and TotalLength zero, FragmentOffset non-zero",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: 0,
- ID: ident,
- Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments,
- FragmentOffset: 59776,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- overrideIHL: 1, // See note above.
- payload: payloadGen(12),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 0,
- },
- {
- name: "IHL and TotalLength zero, FragmentOffset zero",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: 0,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- overrideIHL: 1, // See note above.
- payload: payloadGen(12),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 0,
- },
- {
- // Payload 17 octets and Fragment offset 65520
- // Leading to the fragment end to be past 65536.
- name: "fragment ends past 65536",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 17,
- ID: ident,
- Flags: 0,
- FragmentOffset: 65520,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(17),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- },
- {
- // Payload 16 octets and fragment offset 65520
- // Leading to the fragment end to be exactly 65536.
- name: "fragment ends exactly at 65536",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: 0,
- FragmentOffset: 65520,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(16),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 0,
- wantMalformedFragments: 0,
- },
- {
- name: "IHL less than IPv4 minimum size",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 28,
- ID: ident,
- Flags: 0,
- FragmentOffset: 1944,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize - 12,
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize - 12,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize - 12,
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 2,
- wantMalformedFragments: 0,
- },
- {
- name: "fragment with short TotalLength and extra payload",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 28,
- ID: ident,
- Flags: 0,
- FragmentOffset: 28816,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize + 4,
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 4,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize + 4,
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- },
- {
- name: "multiple fragments with More Fragments flag set to false",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: 0,
- FragmentOffset: 128,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(8),
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: 0,
- FragmentOffset: 8,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(8),
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(8),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocol,
- },
- })
- e := channel.New(0, 1500, linkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
- }
-
- for _, f := range test.fragments {
- pktSize := header.IPv4MinimumSize + len(f.payload)
- hdr := buffer.NewPrependable(pktSize)
-
- ip := header.IPv4(hdr.Prepend(pktSize))
- ip.Encode(&f.ipv4fields)
- if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got {
- t.Fatalf("copied %d bytes, expected %d bytes.", got, want)
- }
- // Encode sets this up correctly. If we want a different value for
- // testing then we need to overwrite the good value.
- if f.overrideIHL != 0 {
- ip.SetHeaderLength(uint8(f.overrideIHL))
- // If we are asked to add options (type not specified) then pad
- // with 0 (EOL). RFC 791 page 23 says "The padding is zero".
- for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ {
- ip[i] = byte(header.IPv4OptionListEndType)
- }
- }
-
- if f.autoChecksum {
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- }
-
- vv := hdr.View().ToVectorisedView()
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
- t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
- }
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
- t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
- }
- })
- }
-}
-
-func TestFragmentReassemblyTimeout(t *testing.T) {
- const (
- nicID = 1
- linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
- tos = 0
- ident = 1
- ttl = 48
- protocol = 99
- data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
- )
-
- type fragmentData struct {
- ipv4fields header.IPv4Fields
- payload []byte
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- expectICMP bool
- }{
- {
- name: "first fragment only",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "two first fragments",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:16],
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "second fragment only",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
- ID: ident,
- Flags: 0,
- FragmentOffset: 8,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: false,
- },
- {
- name: "two fragments with a gap",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:8],
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
- ID: ident,
- Flags: 0,
- FragmentOffset: 16,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: true,
- },
- {
- name: "two fragments with a gap in reverse order",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
- ID: ident,
- Flags: 0,
- FragmentOffset: 16,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[16:],
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:8],
- },
- },
- expectICMP: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocol,
- },
- Clock: clock,
- })
- e := channel.New(1, 1500, linkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- }})
-
- var firstFragmentSent buffer.View
- for _, f := range test.fragments {
- pktSize := header.IPv4MinimumSize
- hdr := buffer.NewPrependable(pktSize)
-
- ip := header.IPv4(hdr.Prepend(pktSize))
- ip.Encode(&f.ipv4fields)
-
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(f.payload)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
-
- if firstFragmentSent == nil && ip.FragmentOffset() == 0 {
- firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
- }
-
- e.InjectInbound(header.IPv4ProtocolNumber, pkt)
- }
-
- clock.Advance(ipv4.ReassembleTimeout)
-
- reply, ok := e.Read()
- if !test.expectICMP {
- if ok {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
- return
- }
- if !ok {
- t.Fatal("expected ICMP error message missing")
- }
- if firstFragmentSent == nil {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
-
- checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(addr2),
- checker.DstAddr(addr1),
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4TimeExceeded),
- checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
- checker.ICMPv4Checksum(),
- checker.ICMPv4Payload([]byte(firstFragmentSent)),
- ),
- )
- })
- }
-}
-
-// TestReceiveFragments feeds fragments in through the incoming packet path to
-// test reassembly
-func TestReceiveFragments(t *testing.T) {
- const (
- nicID = 1
-
- addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
- )
-
- // Build and return a UDP header containing payload.
- udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View {
- payload := buffer.NewView(payloadLen)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i) * multiplier
- }
-
- udpLength := header.UDPMinimumSize + len(payload)
-
- hdr := buffer.NewPrependable(udpLength)
- u := header.UDP(hdr.Prepend(udpLength))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: uint16(udpLength),
- })
- copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
- sum = header.Checksum(payload, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- return hdr.View()
- }
-
- // UDP header plus a payload of 0..256
- ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2)
- udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:]
- ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2)
- udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:]
- // UDP header plus a payload of 0..256 in increments of 2.
- ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2)
- udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:]
- // UDP header plus a payload of 0..256 in increments of 3.
- // Used to test cases where the fragment blocks are not a multiple of
- // the fragment block size of 8 (RFC 791 section 3.1 page 14).
- ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
- udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
- // Used to test the max reassembled IPv4 payload length.
- ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
- udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
-
- type fragmentData struct {
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- id uint16
- flags uint8
- fragmentOffset uint16
- payload buffer.View
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- expectedPayloads [][]byte
- }{
- {
- name: "No fragmentation",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2,
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "No fragmentation with size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 0,
- payload: ipv4Payload3Addr1ToAddr2,
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "More fragments without payload",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2,
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Non-zero fragment offset without payload",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 8,
- payload: ipv4Payload1Addr1ToAddr2,
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments out of order",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with last fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload3Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload3Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "Two fragments with first fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload3Addr1ToAddr2[:63],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 63,
- payload: ipv4Payload3Addr1ToAddr2[63:],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Second fragment has MoreFlags set",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with different IDs",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 2,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two interleaved fragmented packets",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 2,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload2Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 2,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload2Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
- },
- {
- name: "Two interleaved fragmented packets from different sources but with same ID",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr3ToAddr2[:32],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 32,
- payload: ipv4Payload1Addr3ToAddr2[32:],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
- },
- {
- name: "Fragment without followup",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload4Addr1ToAddr2[:65512],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 65512,
- payload: ipv4Payload4Addr1ToAddr2[65512:],
- },
- },
- expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
- },
- {
- name: "Two fragments with MF flag reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload4Addr1ToAddr2[:65512],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 65512,
- payload: ipv4Payload4Addr1ToAddr2[65512:],
- },
- },
- expectedPayloads: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- // Setup a stack and endpoint.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- RawFactory: raw.EndpointFactory{},
- })
- e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
- }
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- // Bring up a raw endpoint so we can examine network headers.
- epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */)
- if err != nil {
- t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
- }
- defer epRaw.Close()
-
- // Prepare and send the fragments.
- for _, frag := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv4MinimumSize)
-
- // Serialize IPv4 fixed header.
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)),
- ID: frag.id,
- Flags: frag.flags,
- FragmentOffset: frag.fragmentOffset,
- TTL: 64,
- Protocol: uint8(header.UDPProtocolNumber),
- SrcAddr: frag.srcAddr,
- DstAddr: frag.dstAddr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(frag.payload)
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
- t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
- }
-
- for i, expectedPayload := range test.expectedPayloads {
- // Check UDP payload delivered by UDP endpoint.
- var buf bytes.Buffer
- result, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("(i=%d) ep.Read: %s", i, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: len(expectedPayload),
- Total: len(expectedPayload),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff)
- }
- if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" {
- t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff)
- }
-
- // Check IPv4 header in packet delivered by raw endpoint.
- buf.Reset()
- result, err = epRaw.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("(i=%d) epRaw.Read: %s", i, err)
- }
- // Reassambly does not take care of checksum. Here we write our own
- // check routine instead of using checker.IPv4.
- ip := header.IPv4(buf.Bytes())
- for _, check := range []checker.NetworkChecker{
- checker.FragmentFlags(0),
- checker.FragmentOffset(0),
- checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))),
- } {
- check(t, []header.Network{ip})
- }
- }
-
- res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-func TestWriteStats(t *testing.T) {
- const nPackets = 3
-
- tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectOutputDropped int
- expectPostroutingDropped int
- expectWritten int
- }{
- {
- name: "Accept all",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Accept all with error",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets - 1,
- }, {
- name: "Drop all with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: nPackets,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop all with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: 0,
- expectPostroutingDropped: nPackets,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 1,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Postrouting DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 1,
- expectWritten: nPackets,
- },
- }
-
- // Parameterize the tests to run with both WritePacket and WritePackets.
- writers := []struct {
- name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
- }{
- {
- name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- nWritten := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil {
- return nWritten, err
- }
- nWritten++
- }
- return nWritten, nil
- },
- }, {
- name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- return rt.WritePackets(pkts, stack.NetworkHeaderParams{})
- },
- },
- }
-
- for _, writer := range writers {
- t.Run(writer.name, func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
- rt := buildRoute(t, ep)
-
- var pkts stack.PacketBufferList
- for i := 0; i < nPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
- Data: buffer.NewView(0).ToVectorisedView(),
- })
- pkt.TransportHeader().Push(header.UDPMinimumSize)
- pkts.PushBack(pkt)
- }
-
- test.setup(t, rt.Stack())
-
- nWritten, _ := writer.writePackets(rt, pkts)
-
- if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
- }
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
- t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
- }
- if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
- t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
- }
- if nWritten != test.expectWritten {
- t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
- }
- })
- }
- })
- }
-}
-
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC(1, _) failed: %s", err)
- }
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
- }
- {
- mask := tcpip.AddressMask(header.IPv4Broadcast)
- subnet, err := tcpip.NewSubnet(dst, mask)
- if err != nil {
- t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
- }
- return rt
-}
-
-// limitedMatcher is an iptables matcher that matches after a certain number of
-// packets are checked against it.
-type limitedMatcher struct {
- limit int
-}
-
-// Name implements Matcher.Name.
-func (*limitedMatcher) Name() string {
- return "limitedMatcher"
-}
-
-// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
- if lm.limit == 0 {
- return true, false
- }
- lm.limit--
- return false, false
-}
-
-func TestPacketQueing(t *testing.T) {
- const nicID = 1
-
- var (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-
- host1IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
- PrefixLen: 24,
- },
- }
- host2IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
- PrefixLen: 8,
- },
- }
- )
-
- tests := []struct {
- name string
- rxPkt func(*channel.Endpoint)
- checkResp func(*testing.T, *channel.Endpoint)
- }{
- {
- name: "ICMP Error",
- rxPkt: func(e *channel.Endpoint) {
- hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
- u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: header.UDPMinimumSize,
- })
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(header.UDP([]byte{}), sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
- TTL: ipv4.DefaultTTL,
- Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
- e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != header.IPv4ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
- },
- },
-
- {
- name: "Ping",
- rxPkt: func(e *channel.Endpoint) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ipv4.DefaultTTL,
- SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != header.IPv4ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4EchoReply),
- checker.ICMPv4Code(header.ICMPv4UnusedCode)))
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(1, defaultMTU, host1NICLinkAddr)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: nicID,
- },
- })
-
- // Receive a packet to trigger link resolution before a response is sent.
- test.rxPkt(e)
-
- // Wait for a ARP request since link address resolution should be
- // performed.
- {
- p, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != arp.ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
- }
- rep := header.ARP(p.Pkt.NetworkHeader().View())
- if got := rep.Op(); got != header.ARPRequest {
- t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
- }
- if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
- }
- if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
- }
- }
-
- // Send an ARP reply to complete link address resolution.
- {
- hdr := buffer.View(make([]byte, header.ARPSize))
- packet := header.ARP(hdr)
- packet.SetIPv4OverEthernet()
- packet.SetOp(header.ARPReply)
- copy(packet.HardwareAddressSender(), host2NICLinkAddr)
- copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
- copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
- copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
- e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.ToVectorisedView(),
- }))
- }
-
- // Expect the response now that the link address has resolved.
- test.checkResp(t, e)
-
- // Since link resolution was already performed, it shouldn't be performed
- // again.
- test.rxPkt(e)
- test.checkResp(t, e)
- })
- }
-}
-
-// TestCloseLocking test that lock ordering is followed when closing an
-// endpoint.
-func TestCloseLocking(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- iterations = 1000
- )
-
- var (
- src = tcptestutil.MustParse4("16.0.0.1")
- dst = tcptestutil.MustParse4("16.0.0.2")
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- // Perform NAT so that the endoint tries to search for a sibling endpoint
- // which ends up taking the protocol and endpoint lock (in that order).
- table := stack.Table{
- Rules: []stack.Rule{
- {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 1,
- stack.Forward: stack.HookUnset,
- stack.Output: 2,
- stack.Postrouting: 3,
- },
- Underflows: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 1,
- stack.Forward: stack.HookUnset,
- stack.Output: 2,
- stack.Postrouting: 3,
- },
- }
- if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil {
- t.Fatalf("s.IPTables().ReplaceTable(...): %s", err)
- }
-
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
-
- if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID1,
- }})
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53}
- if err := ep.Connect(addr); err != nil {
- t.Errorf("ep.Connect(%#v): %s", addr, err)
- }
-
- var wg sync.WaitGroup
- defer wg.Wait()
-
- // Writing packets should trigger NAT which requires the stack to search the
- // protocol for network endpoints with the destination address.
- //
- // Creating and removing interfaces should modify the protocol and endpoint
- // which requires taking the locks of each.
- //
- // We expect the protocol > endpoint lock ordering to be followed here.
- wg.Add(2)
- go func() {
- defer wg.Done()
-
- data := []byte{1, 2, 3, 4}
-
- for i := 0; i < iterations; i++ {
- var r bytes.Reader
- r.Reset(data)
- if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Errorf("ep.Write(_, _): %s", err)
- return
- } else if want := int64(len(data)); n != want {
- t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
- return
- }
- }
- }()
- go func() {
- defer wg.Done()
-
- for i := 0; i < iterations; i++ {
- if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
- t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
- return
- }
- if err := s.RemoveNIC(nicID2); err != nil {
- t.Errorf("RemoveNIC(%d): %s", nicID2, err)
- return
- }
- }
- }()
-}
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
deleted file mode 100644
index d1f9e3cf5..000000000
--- a/pkg/tcpip/network/ipv4/stats_test.go
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv4
-
-import (
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.NetworkInterface
- nicID tcpip.NICID
-}
-
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- nic := testInterface{nicID: 1}
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
-
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEP := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEP {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
-func TestMultiCounterStatsInitialization(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- // At this point, the Stack's stats and the NetworkEndpoint's stats are
- // expected to be bound by a MultiCounterStat.
- refStack := s.Stats()
- refEP := ep.stats.localStats
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil {
- t.Error(err)
- }
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil {
- t.Error(err)
- }
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
deleted file mode 100644
index f99cbf8f3..000000000
--- a/pkg/tcpip/network/ipv6/BUILD
+++ /dev/null
@@ -1,72 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ipv6",
- srcs = [
- "dhcpv6configurationfromndpra_string.go",
- "icmp.go",
- "ipv6.go",
- "mld.go",
- "ndp.go",
- "stats.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/hash",
- "//pkg/tcpip/network/internal/fragmentation",
- "//pkg/tcpip/network/internal/ip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ipv6_test",
- size = "small",
- srcs = [
- "icmp_test.go",
- "ipv6_test.go",
- "ndp_test.go",
- ],
- library = ":ipv6",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/internal/testutil",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "ipv6_x_test",
- size = "small",
- srcs = ["mld_test.go"],
- deps = [
- ":ipv6",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- ],
-)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
deleted file mode 100644
index e457be3cf..000000000
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ /dev/null
@@ -1,1716 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "bytes"
- "context"
- "net"
- "reflect"
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- nicID = 1
-
- linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
-
- defaultChannelSize = 1
- defaultMTU = 65536
-
- // Extra time to use when waiting for an async event to occur.
- defaultAsyncPositiveEventTimeout = 30 * time.Second
-
- arbitraryHopLimit = 42
-)
-
-var (
- lladdr0 = header.LinkLocalAddr(linkAddr0)
- lladdr1 = header.LinkLocalAddr(linkAddr1)
- lladdr2 = header.LinkLocalAddr(linkAddr2)
-)
-
-type stubLinkEndpoint struct {
- stack.LinkEndpoint
-}
-
-func (*stubLinkEndpoint) MTU() uint32 {
- return defaultMTU
-}
-
-func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- // Indicate that resolution for link layer addresses is required to send
- // packets over this link. This is needed so the NIC knows to allocate a
- // neighbor table.
- return stack.CapabilityResolutionRequired
-}
-
-func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
- return nil
-}
-
-func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {}
-
-type stubDispatcher struct {
- stack.TransportDispatcher
-}
-
-func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
- return stack.TransportPacketHandled
-}
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.LinkEndpoint
-
- probeCount int
- confirmationCount int
-
- nicID tcpip.NICID
-}
-
-func (*testInterface) ID() tcpip.NICID {
- return nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (*testInterface) Spoofing() bool {
- return false
-}
-
-func (t *testInterface) WritePacket(r *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- return t.LinkEndpoint.WritePacket(r.Fields(), protocol, pkt)
-}
-
-func (t *testInterface) WritePackets(r *stack.Route, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return t.LinkEndpoint.WritePackets(r.Fields(), pkts, protocol)
-}
-
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- var r stack.RouteInfo
- r.NetProto = protocol
- r.RemoteLinkAddress = remoteLinkAddr
- return t.LinkEndpoint.WritePacket(r, protocol, pkt)
-}
-
-func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
- t.probeCount++
- return nil
-}
-
-func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
- t.confirmationCount++
- return nil
-}
-
-func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
- return tcpip.AddressWithPrefix{}, nil
-}
-
-func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
- return false
-}
-
-func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6, hopLimit uint8, includeRouterAlert bool) {
- var extensionHeaders header.IPv6ExtHdrSerializer
- if includeRouterAlert {
- extensionHeaders = header.IPv6ExtHdrSerializer{
- header.IPv6SerializableHopByHopExtHdr{
- &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
- },
- }
- }
- ip := buffer.NewView(header.IPv6MinimumSize + extensionHeaders.Length())
- header.IPv6(ip).Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: hopLimit,
- SrcAddr: src,
- DstAddr: dst,
- ExtensionHeaders: extensionHeaders,
- })
-
- vv := ip.ToVectorisedView()
- vv.AppendView(buffer.View(icmp))
- ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
-}
-
-func TestICMPCounts(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
- ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- typ header.ICMPv6Type
- hopLimit uint8
- includeRouterAlert bool
- size int
- extraData []byte
- }{
- {
- typ: header.ICMPv6DstUnreachable,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6DstUnreachableMinimumSize,
- },
- {
- typ: header.ICMPv6PacketTooBig,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6PacketTooBigMinimumSize,
- },
- {
- typ: header.ICMPv6TimeExceeded,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6ParamProblem,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6EchoRequest,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6EchoReply,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6RouterSolicit,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6RouterAdvert,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- },
- {
- typ: header.ICMPv6NeighborSolicit,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- },
- {
- typ: header.ICMPv6NeighborAdvert,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- },
- {
- typ: header.ICMPv6RedirectMsg,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6MulticastListenerQuery,
- hopLimit: header.MLDHopLimit,
- includeRouterAlert: true,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: header.ICMPv6MulticastListenerReport,
- hopLimit: header.MLDHopLimit,
- includeRouterAlert: true,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: header.ICMPv6MulticastListenerDone,
- hopLimit: header.MLDHopLimit,
- includeRouterAlert: true,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: 255, /* Unrecognized */
- size: 50,
- },
- }
-
- for _, typ := range types {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp[:typ.size],
- Src: lladdr0,
- Dst: lladdr1,
- PayloadCsum: header.Checksum(typ.extraData, 0 /* initial */),
- PayloadLen: len(typ.extraData),
- }))
- handleICMPInIPv6(ep, lladdr1, lladdr0, icmp, typ.hopLimit, typ.includeRouterAlert)
- }
-
- // Construct an empty ICMP packet so that
- // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)), arbitraryHopLimit, false)
-
- icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
- visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
- if got, want := s.Value(), uint64(1); got != want {
- t.Errorf("got %s = %d, want = %d", name, got, want)
- }
- })
- if t.Failed() {
- t.Logf("stats:\n%+v", s.Stats())
- }
-}
-
-func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
- t := v.Type()
- for i := 0; i < v.NumField(); i++ {
- v := v.Field(i)
- if s, ok := v.Interface().(*tcpip.StatCounter); ok {
- f(t.Field(i).Name, s)
- } else {
- visitStats(v, f)
- }
- }
-}
-
-type testContext struct {
- s0 *stack.Stack
- s1 *stack.Stack
-
- linkEP0 *channel.Endpoint
- linkEP1 *channel.Endpoint
-}
-
-type endpointWithResolutionCapability struct {
- stack.LinkEndpoint
-}
-
-func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
- return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
-}
-
-func newTestContext(t *testing.T) *testContext {
- c := &testContext{
- s0: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- }),
- s1: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- }),
- }
-
- c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
-
- wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
- if testing.Verbose() {
- wrappedEP0 = sniffer.New(wrappedEP0)
- }
- if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
- t.Fatalf("CreateNIC s0: %v", err)
- }
- if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
- }
-
- c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
- wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
- if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
- }
-
- subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- c.s0.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet0,
- NIC: nicID,
- }},
- )
- subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
- if err != nil {
- t.Fatal(err)
- }
- c.s1.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet1,
- NIC: nicID,
- }},
- )
-
- t.Cleanup(func() {
- if err := c.s0.RemoveNIC(nicID); err != nil {
- t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err)
- }
- if err := c.s1.RemoveNIC(nicID); err != nil {
- t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err)
- }
-
- c.linkEP0.Close()
- c.linkEP1.Close()
- })
-
- return c
-}
-
-type routeArgs struct {
- src, dst *channel.Endpoint
- typ header.ICMPv6Type
- remoteLinkAddr tcpip.LinkAddress
-}
-
-func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
- t.Helper()
-
- pi, _ := args.src.ReadContext(context.Background())
-
- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()),
- })
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt)
- }
-
- if pi.Proto != ProtocolNumber {
- t.Errorf("unexpected protocol number %d", pi.Proto)
- return
- }
-
- if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr {
- t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
- }
-
- // Pull the full payload since network header. Needed for header.IPv6 to
- // extract its payload.
- ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader()))
- transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
- if transProto != header.ICMPv6ProtocolNumber {
- t.Errorf("unexpected transport protocol number %d", transProto)
- return
- }
- icmpv6 := header.ICMPv6(ipv6.Payload())
- if got, want := icmpv6.Type(), args.typ; got != want {
- t.Errorf("got ICMPv6 type = %d, want = %d", got, want)
- return
- }
- if fn != nil {
- fn(t, icmpv6)
- }
-}
-
-func TestLinkResolution(t *testing.T) {
- c := newTestContext(t)
-
- r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
- }
- defer r.Release()
-
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: r.LocalAddress(),
- Dst: r.RemoteAddress(),
- }))
-
- // We can't send our payload directly over the route because that
- // doesn't provoke NDP discovery.
- var wq waiter.Queue
- ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
- }
-
- {
- var r bytes.Reader
- r.Reset(hdr.View())
- if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
- t.Fatalf("ep.Write(_): %s", err)
- }
- }
- for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
- {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
- } {
- routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
- if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
- t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
- }
- })
- }
-
- for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
- {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
- } {
- routeICMPv6Packet(t, args, nil)
- }
-}
-
-func TestICMPChecksumValidationSimple(t *testing.T) {
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- routerOnly bool
- }{
- {
- name: "DstUnreachable",
- typ: header.ICMPv6DstUnreachable,
- size: header.ICMPv6DstUnreachableMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- },
- {
- name: "PacketTooBig",
- typ: header.ICMPv6PacketTooBig,
- size: header.ICMPv6PacketTooBigMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- },
- {
- name: "TimeExceeded",
- typ: header.ICMPv6TimeExceeded,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- },
- {
- name: "ParamProblem",
- typ: header.ICMPv6ParamProblem,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- },
- {
- name: "EchoRequest",
- typ: header.ICMPv6EchoRequest,
- size: header.ICMPv6EchoMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- },
- {
- name: "EchoReply",
- typ: header.ICMPv6EchoReply,
- size: header.ICMPv6EchoMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- },
- {
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- // Hosts MUST silently discard any received Router Solicitation messages.
- routerOnly: true,
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
- },
- {
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
- },
- {
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- },
- },
- }
-
- for _, typ := range types {
- for _, isRouter := range []bool{false, true} {
- name := typ.name
- if isRouter {
- name += " (Router)"
- }
- t.Run(name, func(t *testing.T) {
- e := channel.New(0, 1280, linkAddr0)
-
- // Indicate that resolution for link layer addresses is required to
- // send packets over this link. This is needed so the NIC knows to
- // allocate a neighbor table.
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
- }
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
-
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- handleIPv6Payload := func(checksum bool) {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- if checksum {
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- }
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- routerOnly := stats.RouterOnlyPacketsDroppedByHost
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := routerOnly.Value(); got != 0 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Router only count should not have increased.
- if got := routerOnly.Value(); got != 0 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- if !isRouter && typ.routerOnly {
- // Router only count should have increased.
- if got := routerOnly.Value(); got != 1 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
- }
- }
- })
- }
- }
-}
-
-func TestICMPChecksumValidationWithPayload(t *testing.T) {
- const simpleBodySize = 64
- simpleBody := func(view buffer.View) {
- for i := 0; i < simpleBodySize; i++ {
- view[i] = uint8(i)
- }
- }
-
- const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
- errorICMPBody := func(view buffer.View) {
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- TransportProtocol: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
- })
- simpleBody(view[header.IPv6MinimumSize:])
- }
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- payloadSize int
- payload func(buffer.View)
- }{
- {
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- simpleBodySize,
- simpleBody,
- },
- {
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- simpleBodySize,
- simpleBody,
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
-
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
- icmpSize := size + payloadSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize))
- icmpHdr.SetType(typ)
- payloadFn(icmpHdr.Payload())
-
- if checksum {
- icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmpHdr,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- }
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got = %d, want = 0", got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
- const simpleBodySize = 64
- simpleBody := func(view buffer.View) {
- for i := 0; i < simpleBodySize; i++ {
- view[i] = uint8(i)
- }
- }
-
- const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
- errorICMPBody := func(view buffer.View) {
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- TransportProtocol: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
- })
- simpleBody(view[header.IPv6MinimumSize:])
- }
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- payloadSize int
- payload func(buffer.View)
- }{
- {
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- simpleBodySize,
- simpleBody,
- },
- {
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- simpleBodySize,
- simpleBody,
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- icmpHdr := header.ICMPv6(hdr.Prepend(size))
- icmpHdr.SetType(typ)
-
- payload := buffer.NewView(payloadSize)
- payloadFn(payload)
-
- if checksum {
- icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmpHdr,
- Src: lladdr1,
- Dst: lladdr0,
- PayloadCsum: header.Checksum(payload, 0 /* initial */),
- PayloadLen: len(payload),
- }))
- }
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got = %d, want = 0", got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestLinkAddressRequest(t *testing.T) {
- const nicID = 1
-
- snaddr := header.SolicitedNodeAddr(lladdr0)
- mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
-
- tests := []struct {
- name string
- nicAddr tcpip.Address
- localAddr tcpip.Address
- remoteLinkAddr tcpip.LinkAddress
-
- expectedErr tcpip.Error
- expectedRemoteAddr tcpip.Address
- expectedRemoteLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Unicast",
- nicAddr: lladdr1,
- localAddr: lladdr1,
- remoteLinkAddr: linkAddr1,
- expectedRemoteAddr: lladdr0,
- expectedRemoteLinkAddr: linkAddr1,
- },
- {
- name: "Multicast",
- nicAddr: lladdr1,
- localAddr: lladdr1,
- remoteLinkAddr: "",
- expectedRemoteAddr: snaddr,
- expectedRemoteLinkAddr: mcaddr,
- },
- {
- name: "Unicast with unspecified source",
- nicAddr: lladdr1,
- remoteLinkAddr: linkAddr1,
- expectedRemoteAddr: lladdr0,
- expectedRemoteLinkAddr: linkAddr1,
- },
- {
- name: "Multicast with unspecified source",
- nicAddr: lladdr1,
- remoteLinkAddr: "",
- expectedRemoteAddr: snaddr,
- expectedRemoteLinkAddr: mcaddr,
- },
- {
- name: "Unicast with unassigned address",
- localAddr: lladdr1,
- remoteLinkAddr: linkAddr1,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- {
- name: "Multicast with unassigned address",
- localAddr: lladdr1,
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- {
- name: "Unicast with no local address available",
- remoteLinkAddr: linkAddr1,
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- },
- {
- name: "Multicast with no local address available",
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
-
- linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
- }
- linkRes, ok := ep.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
- }
-
- if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
- }
- }
-
- {
- err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff)
- }
- }
-
- if test.expectedErr != nil {
- return
- }
-
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
-
- var want stack.RouteInfo
- want.NetProto = ProtocolNumber
- want.RemoteLinkAddress = test.expectedRemoteLinkAddr
- if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" {
- t.Errorf("route info mismatch (-want +got):\n%s", diff)
- }
- checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
- checker.SrcAddr(lladdr1),
- checker.DstAddr(test.expectedRemoteAddr),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
- ))
- })
- }
-}
-
-func TestPacketQueing(t *testing.T) {
- const nicID = 1
-
- var (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-
- host1IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::1").To16()),
- PrefixLen: 64,
- },
- }
- host2IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::2").To16()),
- PrefixLen: 64,
- },
- }
- )
-
- tests := []struct {
- name string
- rxPkt func(*channel.Endpoint)
- checkResp func(*testing.T, *channel.Endpoint)
- }{
- {
- name: "ICMP Error",
- rxPkt: func(e *channel.Endpoint) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: header.UDPMinimumSize,
- })
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(header.UDP([]byte{}), sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
- })
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
- },
- },
-
- {
- name: "Ping",
- rxPkt: func(e *channel.Endpoint) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: host2IPv6Addr.AddressWithPrefix.Address,
- Dst: host1IPv6Addr.AddressWithPrefix.Address,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
- })
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoReply),
- checker.ICMPv6Code(header.ICMPv6UnusedCode)))
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
-
- e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: nicID,
- },
- })
-
- // Receive a packet to trigger link resolution before a response is sent.
- test.rxPkt(e)
-
- // Wait for a neighbor solicitation since link address resolution should
- // be performed.
- {
- p, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != ProtocolNumber {
- t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
- }
- snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
- checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
- ))
- }
-
- // Send a neighbor advertisement to complete link address resolution.
- {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
- pkt := header.ICMPv6(hdr.Prepend(naSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
- })
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: host2IPv6Addr.AddressWithPrefix.Address,
- Dst: host1IPv6Addr.AddressWithPrefix.Address,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: header.NDPHopLimit,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
- })
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- // Expect the response now that the link address has resolved.
- test.checkResp(t, e)
-
- // Since link resolution was already performed, it shouldn't be performed
- // again.
- test.rxPkt(e)
- test.checkResp(t, e)
- })
- }
-}
-
-func TestCallsToNeighborCache(t *testing.T) {
- tests := []struct {
- name string
- createPacket func() header.ICMPv6
- multicast bool
- source tcpip.Address
- destination tcpip.Address
- wantProbeCount int
- wantConfirmationCount int
- }{
- {
- name: "Unicast Neighbor Solicitation without source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- // "The source link-layer address option SHOULD be included in unicast
- // solicitations." - RFC 4861 section 4.3
- //
- // A Neighbor Advertisement needs to be sent in response, but the
- // Neighbor Cache shouldn't be updated since we have no useful
- // information about the sender.
- wantProbeCount: 0,
- },
- {
- name: "Unicast Neighbor Solicitation with source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- wantProbeCount: 1,
- },
- {
- name: "Multicast Neighbor Solicitation without source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- return icmp
- },
- source: lladdr1,
- destination: header.SolicitedNodeAddr(lladdr0),
- // "The source link-layer address option MUST be included in multicast
- // solicitations." - RFC 4861 section 4.3
- wantProbeCount: 0,
- },
- {
- name: "Multicast Neighbor Solicitation with source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: header.SolicitedNodeAddr(lladdr0),
- wantProbeCount: 1,
- },
- {
- name: "Unicast Neighbor Advertisement without target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- // "When responding to unicast solicitations, the target link-layer
- // address option can be omitted since the sender of the solicitation has
- // the correct link-layer address; otherwise, it would not be able to
- // send the unicast solicitation in the first place."
- // - RFC 4861 section 4.4
- wantConfirmationCount: 1,
- },
- {
- name: "Unicast Neighbor Advertisement with target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- wantConfirmationCount: 1,
- },
- {
- name: "Multicast Neighbor Advertisement without target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(false)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- return icmp
- },
- source: lladdr1,
- destination: header.IPv6AllNodesMulticastAddress,
- // "Target link-layer address MUST be included for multicast solicitations
- // in order to avoid infinite Neighbor Solicitation "recursion" when the
- // peer node does not have a cache entry to return a Neighbor
- // Advertisements message." - RFC 4861 section 4.4
- wantConfirmationCount: 0,
- },
- {
- name: "Multicast Neighbor Advertisement with target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(false)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: header.IPv6AllNodesMulticastAddress,
- wantConfirmationCount: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
-
- testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}
- ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{})
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: test.source,
- Dst: test.destination,
- }))
- handleICMPInIPv6(ep, test.source, test.destination, icmp, header.NDPHopLimit, false)
-
- // Confirm the endpoint calls the correct NUDHandler method.
- if testInterface.probeCount != test.wantProbeCount {
- t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount)
- }
- if testInterface.confirmationCount != test.wantConfirmationCount {
- t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
new file mode 100644
index 000000000..13d427822
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
@@ -0,0 +1,136 @@
+// automatically generated by stateify.
+
+package ipv6
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (i *icmpv6DestinationUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationUnreachableSockError"
+}
+
+func (i *icmpv6DestinationUnreachableSockError) StateFields() []string {
+ return []string{}
+}
+
+func (i *icmpv6DestinationUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+}
+
+func (i *icmpv6DestinationUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationNetworkUnreachableSockError"
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv6DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationPortUnreachableSockError"
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv6DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationAddressUnreachableSockError"
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv6DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationAddressUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationAddressUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (e *icmpv6PacketTooBigSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6PacketTooBigSockError"
+}
+
+func (e *icmpv6PacketTooBigSockError) StateFields() []string {
+ return []string{
+ "mtu",
+ }
+}
+
+func (e *icmpv6PacketTooBigSockError) beforeSave() {}
+
+// +checklocksignore
+func (e *icmpv6PacketTooBigSockError) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.mtu)
+}
+
+func (e *icmpv6PacketTooBigSockError) afterLoad() {}
+
+// +checklocksignore
+func (e *icmpv6PacketTooBigSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.mtu)
+}
+
+func init() {
+ state.Register((*icmpv6DestinationUnreachableSockError)(nil))
+ state.Register((*icmpv6DestinationNetworkUnreachableSockError)(nil))
+ state.Register((*icmpv6DestinationPortUnreachableSockError)(nil))
+ state.Register((*icmpv6DestinationAddressUnreachableSockError)(nil))
+ state.Register((*icmpv6PacketTooBigSockError)(nil))
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
deleted file mode 100644
index 8ebca735b..000000000
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ /dev/null
@@ -1,3444 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "bytes"
- "encoding/hex"
- "fmt"
- "io/ioutil"
- "math"
- "net"
- "reflect"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- // The least significant 3 bytes are the same as addr2 so both addr2 and
- // addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
- addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
-
- // Tests use the extension header identifier values as uint8 instead of
- // header.IPv6ExtensionHeaderIdentifier.
- hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
- routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier)
- fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
- destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
- noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
- unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier)
-
- extraHeaderReserve = 50
-)
-
-// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
-// expected Neighbor Advertisement received count after receiving the packet.
-func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
- t.Helper()
-
- // Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: src,
- Dst: dst,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- stats := s.Stats().ICMP.V6.PacketsReceived
-
- if got := stats.NeighborAdvert.Value(); got != want {
- t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
- }
-}
-
-// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
-// expected UDP received count after receiving the packet.
-func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
- t.Helper()
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
-
- // Receive UDP Packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: header.UDPMinimumSize,
- })
-
- // UDP pseudo-header checksum.
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
-
- // UDP checksum
- sum = header.Checksum(header.UDP([]byte{}), sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- stat := s.Stats().UDP.PacketsReceived
-
- if got := stat.Value(); got != want {
- t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
- }
-}
-
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
- // sourcePacket does not have its IP Header populated. Let's copy the one
- // from the first fragment.
- source := header.IPv6(packets[0].NetworkHeader().View())
- sourceIPHeadersLen := len(source)
- vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
- source = append(source, vv.ToView()...)
-
- var reassembledPayload buffer.VectorisedView
- for i, fragment := range packets {
- // Confirm that the packet is valid.
- allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
- fragmentIPHeaders := header.IPv6(allBytes.ToView())
- if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
- return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
- }
-
- fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
- if fragmentIPHeadersLength != sourceIPHeadersLen {
- return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
- }
-
- if got := len(fragmentIPHeaders); got > int(mtu) {
- return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
- }
-
- sourceIPHeader := source[:header.IPv6MinimumSize]
- fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
-
- if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
- return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
- }
-
- // We expect the IPv6 Header to be similar across each fragment, besides the
- // payload length.
- sourceIPHeader.SetPayloadLength(0)
- fragmentIPHeader.SetPayloadLength(0)
- if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" {
- return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
- }
-
- if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve {
- return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
- }
- if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
- return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
- }
-
- if len(packets) > 1 {
- // If the source packet was big enough that it needed fragmentation, let's
- // inspect the fragment header. Because no other extension headers are
- // supported, it will always be the last extension header.
- fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
-
- if got := fragmentHeader.More(); got != wantFragments[i].more {
- return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
- }
- if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
- return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
- }
- if got := fragmentHeader.NextHeader(); got != uint8(proto) {
- return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
- }
- }
-
- // Store the reassembled payload as we parse each fragment. The payload
- // includes the Transport header and everything after.
- reassembledPayload.AppendView(fragment.TransportHeader().View())
- reassembledPayload.AppendView(fragment.Data().AsRange().ToOwnedView())
- }
-
- if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" {
- return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
- }
-
- return nil
-}
-
-// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
-// UDP packets destined to the IPv6 link-local all-nodes multicast address.
-func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.TransportProtocolFactory
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
- }{
- {"ICMP", icmp.NewProtocol6, testReceiveICMP},
- {"UDP", udp.NewProtocol, testReceiveUDP},
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
- })
- e := channel.New(10, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- // Should receive a packet destined to the all-nodes
- // multicast address.
- test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
- })
- }
-}
-
-// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
-// packets destined to the IPv6 solicited-node address of an assigned IPv6
-// address.
-func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.TransportProtocolFactory
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
- }{
- {"ICMP", icmp.NewProtocol6, testReceiveICMP},
- {"UDP", udp.NewProtocol, testReceiveUDP},
- }
-
- snmc := header.SolicitedNodeAddr(addr2)
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
- })
- e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- // Should not receive a packet destined to the solicited node address of
- // addr2/addr3 yet as we haven't added those addresses.
- test.rxf(t, s, e, addr1, snmc, 0)
-
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
- }
-
- // Should receive a packet destined to the solicited node address of
- // addr2/addr3 now that we have added added addr2.
- test.rxf(t, s, e, addr1, snmc, 1)
-
- if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
- }
-
- // Should still receive a packet destined to the solicited node address of
- // addr2/addr3 now that we have added addr3.
- test.rxf(t, s, e, addr1, snmc, 2)
-
- if err := s.RemoveAddress(nicID, addr2); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err)
- }
-
- // Should still receive a packet destined to the solicited node address of
- // addr2/addr3 now that we have removed addr2.
- test.rxf(t, s, e, addr1, snmc, 3)
-
- // Make sure addr3's endpoint does not get removed from the NIC by
- // incrementing its reference count with a route.
- r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err)
- }
- defer r.Release()
-
- if err := s.RemoveAddress(nicID, addr3); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err)
- }
-
- // Should not receive a packet destined to the solicited node address of
- // addr2/addr3 yet as both of them got removed, even though a route using
- // addr3 exists.
- test.rxf(t, s, e, addr1, snmc, 3)
- })
- }
-}
-
-// TestAddIpv6Address tests adding IPv6 addresses.
-func TestAddIpv6Address(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- addr tcpip.Address
- }{
- // This test is in response to b/140943433.
- {
- "Nil",
- tcpip.Address([]byte(nil)),
- },
- {
- "ValidUnicast",
- addr1,
- },
- {
- "ValidLinkLocalUnicast",
- lladdr0,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
- }
-
- if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, %d): %s", nicID, ProtocolNumber, err)
- } else if addr.Address != test.addr {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ProtocolNumber, addr.Address, test.addr)
- }
- })
- }
-}
-
-func TestReceiveIPv6ExtHdrs(t *testing.T) {
- tests := []struct {
- name string
- extHdr func(nextHdr uint8) ([]byte, uint8)
- shouldAccept bool
- countersToBeIncremented func(*tcpip.Stats) []*tcpip.StatCounter
- // Should we expect an ICMP response and if so, with what contents?
- expectICMP bool
- ICMPType header.ICMPv6Type
- ICMPCode header.ICMPv6Code
- pointer uint32
- multicast bool
- }{
- {
- name: "None",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
- shouldAccept: true,
- expectICMP: false,
- },
- {
- name: "hopbyhop with router alert option",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter {
- return []*tcpip.StatCounter{stats.IP.OptionRouterAlertReceived}
- },
- },
- {
- name: "hopbyhop with two router alert options",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0, 0, 0,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter {
- return []*tcpip.StatCounter{
- stats.IP.OptionRouterAlertReceived,
- stats.IP.MalformedPacketsReceived,
- }
- },
- },
- {
- name: "hopbyhop with unknown option skippable action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 62, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "hopbyhop with unknown option discard action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard unknown.
- 127, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action (multicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- multicast: true,
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- multicast: true,
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 1, 0, 2, 3, 4, 5,
- }, routingExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 1, 1, 2, 3, 4, 5,
- }, routingExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6ErroneousHeader,
- pointer: header.IPv6FixedHeaderSize + 2,
- },
- {
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 0, 0, 0, 0, 0, 0,
- }, fragmentExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 0, 0, 1, 2, 3, 4,
- }, fragmentExtHdrID
- },
- shouldAccept: true,
- expectICMP: false,
- },
- {
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 1, 0, 1, 2, 3, 4,
- }, fragmentExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{},
- noNextHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "unknown next header (first)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0, 63, 4, 1, 2, 3, 4,
- }, unknownHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6NextHeaderOffset,
- },
- {
- name: "unknown next header (not first)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- unknownHdrID, 0,
- 63, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6FixedHeaderSize,
- },
- {
- name: "destination with unknown option skippable action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 62, 6, 1, 2, 3, 4, 5, 6,
- }, destinationExtHdrID
- },
- shouldAccept: true,
- expectICMP: false,
- },
- {
- name: "destination with unknown option discard action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard unknown.
- 127, 6, 1, 2, 3, 4, 5, 6,
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "destination with unknown option discard and send icmp action (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ 191 is an unknown option.
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "destination with unknown option discard and send icmp action (muilticast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ 191 is an unknown option.
- }, destinationExtHdrID
- },
- multicast: true,
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ 255 is unknown.
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ 255 is unknown.
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- multicast: true,
- },
- {
- name: "atomic fragment - routing",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Fragment extension header.
- routingExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Routing extension header.
- nextHdr, 0, 1, 0, 2, 3, 4, 5,
- }, fragmentExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "hop by hop (with skippable unknown) - routing",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with skippable unknown option.
- routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- nextHdr, 0, 1, 0, 2, 3, 4, 5,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "routing - hop by hop (with skippable unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Routing extension header.
- hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
- // ^^^ The HopByHop extension header may not appear after the first
- // extension header.
-
- // Hop By Hop extension header with skippable unknown option.
- nextHdr, 0, 62, 4, 1, 2, 3, 4,
- }, routingExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6FixedHeaderSize,
- },
- {
- name: "routing - hop by hop (with send icmp unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Routing extension header.
- hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
- // ^^^ The HopByHop extension header may not appear after the first
- // extension header.
-
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- }, routingExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6FixedHeaderSize,
- },
- {
- name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with skippable unknown option.
- routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
-
- // Fragment extension header.
- destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Destination extension header with skippable unknown option.
- nextHdr, 0, 63, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with discard action for unknown option.
- routingExtHdrID, 0, 65, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
-
- // Fragment extension header.
- destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Destination extension header with skippable unknown option.
- nextHdr, 0, 63, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with skippable unknown option.
- routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
-
- // Fragment extension header.
- destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Destination extension header with discard action for unknown
- // option.
- nextHdr, 0, 65, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
- }
-
- // Add a default route so that a return packet knows where to go.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.WritableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8}
- udpLength := header.UDPMinimumSize + len(udpPayload)
- extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber))
- extHdrLen := len(extHdrBytes)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength)
-
- // Serialize UDP message.
- u := header.UDP(hdr.Prepend(udpLength))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: uint16(udpLength),
- })
- copy(u.Payload(), udpPayload)
-
- dstAddr := tcpip.Address(addr2)
- if test.multicast {
- dstAddr = header.IPv6AllNodesMulticastAddress
- }
-
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength))
- sum = header.Checksum(udpPayload, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- // Copy extension header bytes between the UDP message and the IPv6
- // fixed header.
- copy(hdr.Prepend(extHdrLen), extHdrBytes)
-
- // Serialize IPv6 fixed header.
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- // We're lying about transport protocol here to be able to generate
- // raw extension headers from the test definitions.
- TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr),
- HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: dstAddr,
- })
-
- stats := s.Stats()
- var counters []*tcpip.StatCounter
- // Make sure that the counters we expect to be incremented are initially
- // set to zero.
- if fn := test.countersToBeIncremented; fn != nil {
- counters = fn(&stats)
- }
- for i := range counters {
- if got := counters[i].Value(); got != 0 {
- t.Errorf("before writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 0", i, got)
- }
- }
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- for i := range counters {
- if got := counters[i].Value(); got != 1 {
- t.Errorf("after writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 1", i, got)
- }
- }
-
- udpReceiveStat := stats.UDP.PacketsReceived
- if !test.shouldAccept {
- if got := udpReceiveStat.Value(); got != 0 {
- t.Errorf("got UDP Rx Packets = %d, want = 0", got)
- }
-
- if !test.expectICMP {
- if p, ok := e.Read(); ok {
- t.Fatalf("unexpected packet received: %#v", p)
- }
- return
- }
-
- // ICMP required.
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected packet wasn't written out")
- }
-
- // Pack the output packet into a single buffer.View as the checkers
- // assume that.
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
- t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
- }
-
- ipHdr := header.IPv6(pkt)
- checker.IPv6(t, ipHdr, checker.ICMPv6(
- checker.ICMPv6Type(test.ICMPType),
- checker.ICMPv6Code(test.ICMPCode)))
-
- // We know we are looking at no extension headers in the error ICMP
- // packets.
- icmpPkt := header.ICMPv6(ipHdr.Payload())
- // We know we sent small packets that won't be truncated when reflected
- // back to us.
- originalPacket := icmpPkt.Payload()
- if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
- t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
- }
- if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
- t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
- }
- return
- }
-
- // Expect a UDP packet.
- if got := udpReceiveStat.Value(); got != 1 {
- t.Errorf("got UDP Rx Packets = %d, want = 1", got)
- }
- var buf bytes.Buffer
- result, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Read: %s", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: len(udpPayload),
- Total: len(udpPayload),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have any more UDP packets.
- res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
-type fragmentData struct {
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- nextHdr uint8
- data buffer.VectorisedView
-}
-
-func TestReceiveIPv6Fragments(t *testing.T) {
- const (
- udpPayload1Length = 256
- udpPayload2Length = 128
- // Used to test cases where the fragment blocks are not a multiple of
- // the fragment block size of 8 (RFC 8200 section 4.5).
- udpPayload3Length = 127
- udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
- udpMaximumSizeMinus15 = header.UDPMaximumSize - 15
- fragmentExtHdrLen = 8
- // Note, not all routing extension headers will be 8 bytes but this test
- // uses 8 byte routing extension headers for most sub tests.
- routingExtHdrLen = 8
- )
-
- udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View {
- payloadLen := len(payload)
- for i := 0; i < payloadLen; i++ {
- payload[i] = uint8(i) * multiplier
- }
-
- udpLength := header.UDPMinimumSize + payloadLen
-
- hdr := buffer.NewPrependable(udpLength)
- u := header.UDP(hdr.Prepend(udpLength))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: uint16(udpLength),
- })
- copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
- sum = header.Checksum(payload, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- return hdr.View()
- }
-
- var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte
- udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:]
- ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2)
-
- var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte
- udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:]
- ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2)
-
- var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte
- udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:]
- ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2)
-
- var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte
- udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
- ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
-
- var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte
- udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:]
- ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2)
-
- tests := []struct {
- name string
- expectedPayload []byte
- fragments []fragmentData
- expectedPayloads [][]byte
- }{
- {
- name: "No fragmentation",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: uint8(header.UDPProtocolNumber),
- data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Atomic fragment",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
-
- ipv6Payload1Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Atomic fragment with size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
-
- ipv6Payload3Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "Two fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments out of order",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with different Next Header values",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- // NextHeader value is different than the one in the first fragment, so
- // this NextHeader should be ignored.
- buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with last fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload3Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload3Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "Two fragments with first fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+63,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload3Addr1ToAddr2[:63],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload3Addr1ToAddr2[63:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with different IDs",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
- udpMaximumSizeMinus15 >> 8,
- udpMaximumSizeMinus15 & 0xff,
- 0, 0, 0, 1}),
-
- ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
- },
- {
- name: "Two fragments with MF flag reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
- udpMaximumSizeMinus15 >> 8,
- (udpMaximumSizeMinus15 & 0xff) + 1,
- 0, 0, 0, 1}),
-
- ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with per-fragment routing header with zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 0.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
-
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 0.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
-
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with per-fragment routing header with non-zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 1.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
-
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 1.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
-
- // Fragment extension header.
- //
- // Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with routing header with zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
-
- // Routing extension header.
- //
- // Segments left = 0.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with routing header with non-zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
-
- // Routing extension header.
- //
- // Segments left = 1.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with routing header with zero segments left across fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is fragmentExtHdrLen+8 because the
- // first 8 bytes of the 16 byte routing extension header is in
- // this fragment.
- fragmentExtHdrLen+8,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
-
- // Routing extension header (part 1)
- //
- // Segments left = 0.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
- // the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 1, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
-
- // Routing extension header (part 2)
- buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
-
- ipv6Payload1Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with routing header with non-zero segments left across fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is fragmentExtHdrLen+8 because the
- // first 8 bytes of the 16 byte routing extension header is in
- // this fragment.
- fragmentExtHdrLen+8,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
-
- // Routing extension header (part 1)
- //
- // Segments left = 1.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
- // the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 1, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
-
- // Routing extension header (part 2)
- buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
-
- ipv6Payload1Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal"
- // fragmented traffic.
- {
- name: "Two fragments with atomic",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- // This fragment has the same ID as the other fragments but is an atomic
- // fragment. It should not interfere with the other fragments.
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
-
- ipv6Payload2Addr1ToAddr2,
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two interleaved fragmented packets",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
-
- ipv6Payload2Addr1ToAddr2[:32],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 4, More = false, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
-
- ipv6Payload2Addr1ToAddr2[32:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
- },
- {
- name: "Two interleaved fragmented packets from different sources but with same ID",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
-
- ipv6Payload1Addr3ToAddr2[:32],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 4, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}),
-
- ipv6Payload1Addr3ToAddr2[32:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
- }
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize)
-
- // Serialize IPv6 fixed header.
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(f.data.Size()),
- // We're lying about transport protocol here so that we can generate
- // raw extension headers for the tests.
- TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr),
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
- })
-
- vv := hdr.View().ToVectorisedView()
- vv.Append(f.data)
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
- t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
- }
-
- for i, p := range test.expectedPayloads {
- var buf bytes.Buffer
- _, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("(i=%d) Read: %s", i, err)
- }
- if diff := cmp.Diff(p, buf.Bytes()); diff != "" {
- t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-func TestInvalidIPv6Fragments(t *testing.T) {
- const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- nicID = 1
- hoplimit = 255
- ident = 1
- data = "TEST_INVALID_IPV6_FRAGMENTS"
- )
-
- type fragmentData struct {
- ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
- payload []byte
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- wantMalformedIPPackets uint64
- wantMalformedFragments uint64
- expectICMP bool
- expectICMPType header.ICMPv6Type
- expectICMPCode header.ICMPv6Code
- expectICMPTypeSpecific uint32
- }{
- {
- name: "fragment size is not a multiple of 8 and the M flag is true",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 9,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0 >> 3,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:9],
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- expectICMP: true,
- expectICMPType: header.ICMPv6ParamProblem,
- expectICMPCode: header.ICMPv6ErroneousHeader,
- expectICMPTypeSpecific: header.IPv6PayloadLenOffset,
- },
- {
- name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- expectICMP: true,
- expectICMPType: header.ICMPv6ParamProblem,
- expectICMPCode: header.ICMPv6ErroneousHeader,
- expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- NewProtocol,
- },
- })
- e := channel.New(1, 1500, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- }})
-
- var expectICMPPayload buffer.View
- for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- encodeArgs := f.ipv6Fields
- encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
- ip.Encode(&encodeArgs)
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(f.payload)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
-
- if test.expectICMP {
- expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader())
- }
-
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
- t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want)
- }
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
- t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
- }
-
- reply, ok := e.Read()
- if !test.expectICMP {
- if ok {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
- return
- }
- if !ok {
- t.Fatal("expected ICMP error message missing")
- }
-
- checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(addr2),
- checker.DstAddr(addr1),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())),
- checker.ICMPv6(
- checker.ICMPv6Type(test.expectICMPType),
- checker.ICMPv6Code(test.expectICMPCode),
- checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
- checker.ICMPv6Payload([]byte(expectICMPPayload)),
- ),
- )
- })
- }
-}
-
-func TestFragmentReassemblyTimeout(t *testing.T) {
- const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- nicID = 1
- hoplimit = 255
- ident = 1
- data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
- )
-
- type fragmentData struct {
- ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
- payload []byte
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- expectICMP bool
- }{
- {
- name: "first fragment only",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "two first fragments",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "second fragment only",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 8,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: false,
- },
- {
- name: "two fragments with a gap",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 8,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: true,
- },
- {
- name: "two fragments with a gap in reverse order",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 8,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[16:],
- },
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- NewProtocol,
- },
- Clock: clock,
- })
-
- e := channel.New(1, 1500, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- }})
-
- var firstFragmentSent buffer.View
- for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- encodeArgs := f.ipv6Fields
- encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
- ip.Encode(&encodeArgs)
-
- fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(f.payload)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
-
- if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 {
- firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
- }
-
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- clock.Advance(ReassembleTimeout)
-
- reply, ok := e.Read()
- if !test.expectICMP {
- if ok {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
- return
- }
- if !ok {
- t.Fatal("expected ICMP error message missing")
- }
- if firstFragmentSent == nil {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
-
- checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(addr2),
- checker.DstAddr(addr1),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6TimeExceeded),
- checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
- checker.ICMPv6Payload([]byte(firstFragmentSent)),
- ),
- )
- })
- }
-}
-
-func TestWriteStats(t *testing.T) {
- const nPackets = 3
- tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectOutputDropped int
- expectPostroutingDropped int
- expectWritten int
- }{
- {
- name: "Accept all",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Accept all with error",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets - 1,
- }, {
- name: "Drop all with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: nPackets,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop all with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: 0,
- expectPostroutingDropped: nPackets,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 1,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Postrouting DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 1,
- expectWritten: nPackets,
- },
- }
-
- writers := []struct {
- name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
- }{
- {
- name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- nWritten := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil {
- return nWritten, err
- }
- nWritten++
- }
- return nWritten, nil
- },
- }, {
- name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- return rt.WritePackets(pkts, stack.NetworkHeaderParams{})
- },
- },
- }
-
- for _, writer := range writers {
- t.Run(writer.name, func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
- rt := buildRoute(t, ep)
- var pkts stack.PacketBufferList
- for i := 0; i < nPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
- Data: buffer.NewView(0).ToVectorisedView(),
- })
- pkt.TransportHeader().Push(header.UDPMinimumSize)
- pkts.PushBack(pkt)
- }
-
- test.setup(t, rt.Stack())
-
- nWritten, _ := writer.writePackets(rt, pkts)
-
- if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
- }
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
- t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
- }
- if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
- t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
- }
- if nWritten != test.expectWritten {
- t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
- }
- })
- }
- })
- }
-}
-
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC(1, _) failed: %s", err)
- }
- const (
- src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- )
- if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
- }
- {
- mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
- subnet, err := tcpip.NewSubnet(dst, mask)
- if err != nil {
- t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
- }
- return rt
-}
-
-// limitedMatcher is an iptables matcher that matches after a certain number of
-// packets are checked against it.
-type limitedMatcher struct {
- limit int
-}
-
-// Name implements Matcher.Name.
-func (*limitedMatcher) Name() string {
- return "limitedMatcher"
-}
-
-// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
- if lm.limit == 0 {
- return true, false
- }
- lm.limit--
- return false, false
-}
-
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEndpointAfterClose {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
-type fragmentInfo struct {
- offset uint16
- more bool
- payloadSize uint16
-}
-
-var fragmentationTests = []struct {
- description string
- mtu uint32
- transHdrLen int
- payloadSize int
- wantFragments []fragmentInfo
-}{
- {
- description: "No fragmentation",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 0,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1000, more: false},
- },
- },
- {
- description: "Fragmented",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 0,
- payloadSize: 2000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 776, more: false},
- },
- },
- {
- description: "Fragmented with mtu not a multiple of 8",
- mtu: header.IPv6MinimumMTU + 1,
- transHdrLen: 0,
- payloadSize: 2000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 776, more: false},
- },
- },
- {
- description: "No fragmentation with big header",
- mtu: 2000,
- transHdrLen: 100,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1100, more: false},
- },
- },
- {
- description: "Fragmented with big header",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 100,
- payloadSize: 1200,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 76, more: false},
- },
- },
-}
-
-func TestFragmentationWritePacket(t *testing.T) {
- const (
- ttl = 42
- tos = stack.DefaultTOS
- transportProto = tcp.ProtocolNumber
- )
-
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- source := pkt.Clone()
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if err != nil {
- t.Fatalf("WritePacket(_, _, _): = %s", err)
- }
- if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
- })
- }
-}
-
-func TestFragmentationWritePackets(t *testing.T) {
- const ttl = 42
- tests := []struct {
- description string
- insertBefore int
- insertAfter int
- }{
- {
- description: "Single packet",
- insertBefore: 0,
- insertAfter: 0,
- },
- {
- description: "With packet before",
- insertBefore: 1,
- insertAfter: 0,
- },
- {
- description: "With packet after",
- insertBefore: 0,
- insertAfter: 1,
- },
- {
- description: "With packet before and after",
- insertBefore: 1,
- insertAfter: 1,
- },
- }
- tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
-
- for _, test := range tests {
- t.Run(test.description, func(t *testing.T) {
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- var pkts stack.PacketBufferList
- for i := 0; i < test.insertBefore; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
- pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- source := pkt
- pkts.PushBack(pkt.Clone())
- for i := 0; i < test.insertAfter; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
-
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
-
- wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
- n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- })
- if n != wantTotalPackets || err != nil {
- t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
- }
- if got := len(ep.WrittenPackets); got != wantTotalPackets {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
- }
- if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
-
- if wantTotalPackets == 0 {
- return
- }
-
- fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
- if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
- })
- }
- })
- }
-}
-
-// TestFragmentationErrors checks that errors are returned from WritePacket
-// correctly.
-func TestFragmentationErrors(t *testing.T) {
- const ttl = 42
-
- tests := []struct {
- description string
- mtu uint32
- transHdrLen int
- payloadSize int
- allowPackets int
- outgoingErrors int
- mockError tcpip.Error
- wantError tcpip.Error
- }{
- {
- description: "No frag",
- mtu: 2000,
- payloadSize: 1000,
- transHdrLen: 0,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on first frag",
- mtu: 1300,
- payloadSize: 3000,
- transHdrLen: 0,
- allowPackets: 0,
- outgoingErrors: 3,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on second frag",
- mtu: 1500,
- payloadSize: 4000,
- transHdrLen: 0,
- allowPackets: 1,
- outgoingErrors: 2,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error when MTU is smaller than transport header",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 1500,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: nil,
- wantError: &tcpip.ErrMessageTooLong{},
- },
- {
- description: "Error when MTU is smaller than IPv6 minimum MTU",
- mtu: header.IPv6MinimumMTU - 1,
- transHdrLen: 0,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: nil,
- wantError: &tcpip.ErrInvalidEndpointState{},
- },
- }
-
- for _, ft := range tests {
- t.Run(ft.description, func(t *testing.T) {
- pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
- r := buildRoute(t, ep)
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if diff := cmp.Diff(ft.wantError, err); diff != "" {
- t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
- t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
- }
- if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
- }
- })
- }
-}
-
-func TestForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
- randomSequence = 123
- randomIdent = 42
- )
-
- ipv6Addr1 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10::1").To16()),
- PrefixLen: 64,
- }
- ipv6Addr2 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11::1").To16()),
- PrefixLen: 64,
- }
- remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
- remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
- unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16())
- multicastIPv6Addr := tcpip.Address(net.ParseIP("ff00::").To16())
- linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16())
-
- tests := []struct {
- name string
- extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker)
- TTL uint8
- expectErrorICMP bool
- expectPacketForwarded bool
- countUnrouteablePackets uint64
- sourceAddr tcpip.Address
- destAddr tcpip.Address
- icmpType header.ICMPv6Type
- icmpCode header.ICMPv6Code
- expectPacketUnrouteableError bool
- expectLinkLocalSourceError bool
- expectLinkLocalDestError bool
- expectExtensionHeaderError bool
- }{
- {
- name: "TTL of zero",
- TTL: 0,
- expectErrorICMP: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- icmpType: header.ICMPv6TimeExceeded,
- icmpCode: header.ICMPv6HopLimitExceeded,
- },
- {
- name: "TTL of one",
- TTL: 1,
- expectErrorICMP: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- icmpType: header.ICMPv6TimeExceeded,
- icmpCode: header.ICMPv6HopLimitExceeded,
- },
- {
- name: "TTL of two",
- TTL: 2,
- expectPacketForwarded: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- },
- {
- name: "TTL of three",
- TTL: 3,
- expectPacketForwarded: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- },
- {
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectPacketForwarded: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- },
- {
- name: "Network unreachable",
- TTL: 2,
- expectErrorICMP: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: unreachableIPv6Addr,
- icmpType: header.ICMPv6DstUnreachable,
- icmpCode: header.ICMPv6NetworkUnreachable,
- expectPacketUnrouteableError: true,
- },
- {
- name: "Multicast destination",
- TTL: 2,
- countUnrouteablePackets: 1,
- sourceAddr: remoteIPv6Addr1,
- destAddr: multicastIPv6Addr,
- expectPacketUnrouteableError: true,
- },
- {
- name: "Link local destination",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: linkLocalIPv6Addr,
- expectLinkLocalDestError: true,
- },
- {
- name: "Link local source",
- TTL: 2,
- sourceAddr: linkLocalIPv6Addr,
- destAddr: remoteIPv6Addr2,
- expectLinkLocalSourceError: true,
- },
- {
- name: "Hopbyhop with unknown option skippable action",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 62, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption()))
- },
- expectPacketForwarded: true,
- },
- {
- name: "Hopbyhop with unknown option discard action",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard unknown.
- 127, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action (unicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv6ParamProblem,
- icmpCode: header.ICMPv6UnknownOption,
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action (multicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: multicastIPv6Addr,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv6ParamProblem,
- icmpCode: header.ICMPv6UnknownOption,
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv6ParamProblem,
- icmpCode: header.ICMPv6UnknownOption,
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: multicastIPv6Addr,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with router alert option",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
- }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)))
- },
- expectPacketForwarded: true,
- },
- {
- name: "Hopbyhop with two router alert options",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
- }, hopByHopExtHdrID, nil
- },
- expectExtensionHeaderError: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
- // We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: ipv6Addr1.Subnet(),
- NIC: nicID1,
- },
- {
- Destination: ipv6Addr2.Subnet(),
- NIC: nicID2,
- },
- })
-
- if err := s.SetForwarding(ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
- }
-
- transportProtocol := header.ICMPv6ProtocolNumber
- extHdrBytes := []byte{}
- extHdrChecker := checker.IPv6ExtHdr()
- if test.extHdr != nil {
- nextHdrID := hopByHopExtHdrID
- extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber))
- transportProtocol = tcpip.TransportProtocolNumber(nextHdrID)
- }
- extHdrLen := len(extHdrBytes)
-
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize + extHdrLen)
- icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv6EchoRequest)
- icmp.SetCode(header.ICMPv6UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: test.sourceAddr,
- Dst: test.destAddr,
- }))
- copy(hdr.Prepend(extHdrLen), extHdrBytes)
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: transportProtocol,
- HopLimit: test.TTL,
- SrcAddr: test.sourceAddr,
- DstAddr: test.destAddr,
- })
- requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e1.InjectInbound(ProtocolNumber, requestPkt)
-
- reply, ok := e1.Read()
- if test.expectErrorICMP {
- if !ok {
- t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
- }
-
- checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv6Addr1.Address),
- checker.DstAddr(test.sourceAddr),
- checker.TTL(DefaultTTL),
- checker.ICMPv6(
- checker.ICMPv6Type(test.icmpType),
- checker.ICMPv6Code(test.icmpCode),
- checker.ICMPv6Payload([]byte(hdr.View())),
- ),
- )
-
- if n := e2.Drain(); n != 0 {
- t.Fatalf("got e2.Drain() = %d, want = 0", n)
- }
- } else if ok {
- t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
- }
-
- reply, ok = e2.Read()
- if test.expectPacketForwarded {
- if !ok {
- t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
- }
-
- checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(test.sourceAddr),
- checker.DstAddr(test.destAddr),
- checker.TTL(test.TTL-1),
- extHdrChecker,
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest),
- checker.ICMPv6Code(header.ICMPv6UnusedCode),
- checker.ICMPv6Payload(nil),
- ),
- )
-
- if n := e1.Drain(); n != 0 {
- t.Fatalf("got e1.Drain() = %d, want = 0", n)
- }
- } else if ok {
- t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
- }
-
- boolToInt := func(val bool) uint64 {
- if val {
- return 1
- }
- return 0
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want {
- t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
-func TestMultiCounterStatsInitialization(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- // At this point, the Stack's stats and the NetworkEndpoint's stats are
- // supposed to be bound.
- refStack := s.Stats()
- refEP := ep.stats.localStats
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil {
- t.Error(err)
- }
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
deleted file mode 100644
index 71d1c3e28..000000000
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ /dev/null
@@ -1,601 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6_test
-
-import (
- "bytes"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-var (
- linkLocalAddr = testutil.MustParse6("fe80::1")
- globalAddr = testutil.MustParse6("a80::1")
- globalMulticastAddr = testutil.MustParse6("ff05:100::2")
-
- linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
- globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
-)
-
-func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
- t.Helper()
-
- checker.IPv6WithExtHdr(t, p,
- checker.IPv6ExtHdr(
- checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
- ),
- checker.SrcAddr(localAddress),
- checker.DstAddr(remoteAddress),
- checker.TTL(header.MLDHopLimit),
- checker.MLD(mldType, header.MLDMinimumSize,
- checker.MLDMaxRespDelay(0),
- checker.MLDMulticastAddress(groupAddress),
- ),
- )
-}
-
-func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- // The stack will join an address's solicited node multicast address when
- // an address is added. An MLD report message should be sent for the
- // solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
- }
-
- // The stack will leave an address's solicited node multicast address when
- // an address is removed. An MLD done message should be sent for the
- // solicited-node group.
- if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a done message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
- }
-}
-
-func TestSendQueuedMLDReports(t *testing.T) {
- const (
- nicID = 1
- maxReports = 2
- )
-
- tests := []struct {
- name string
- dadTransmits uint8
- retransmitTimer time.Duration
- }{
- {
- name: "DAD Disabled",
- dadTransmits: 0,
- retransmitTimer: 0,
- },
- {
- name: "DAD Enabled",
- dadTransmits: 1,
- retransmitTimer: time.Second,
- },
- }
-
- nonce := [...]byte{
- 1, 2, 3, 4, 5, 6,
- }
-
- const maxNSMessages = 2
- secureRNGBytes := make([]byte, len(nonce)*maxNSMessages)
- for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] {
- if n := copy(b, nonce[:]); n != len(nonce) {
- t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce))
- }
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
- clock := faketime.NewManualClock()
- var secureRNG bytes.Reader
- secureRNG.Reset(secureRNGBytes[:])
- s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: test.dadTransmits,
- RetransmitTimer: test.retransmitTimer,
- },
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- Clock: clock,
- })
-
- // Allow space for an extra packet so we can observe packets that were
- // unexpectedly sent.
- e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- resolveDAD := func(addr, snmc tcpip.Address) {
- clock.Advance(dadResolutionTime)
- if p, ok := e.Read(); !ok {
- t.Fatal("expected DAD packet")
- } else {
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(addr),
- checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}),
- ))
- }
- }
-
- var reportCounter uint64
- reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- var doneCounter uint64
- doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- if got := doneStat.Value(); got != doneCounter {
- t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
- }
-
- // Joining a group without an assigned address should send an MLD report
- // with the unspecified address.
- if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
- }
- reportCounter++
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Errorf("expected MLD report for %s", globalMulticastAddr)
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Adding a global address should not send reports for the already joined
- // group since we should only send queued reports when a link-local
- // address is assigned.
- //
- // Note, we will still expect to send a report for the global address's
- // solicited node address from the unspecified address as per RFC 3590
- // section 4.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
- }
- reportCounter++
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Errorf("expected MLD report for %s", globalAddrSNMC)
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
- }
- if dadResolutionTime != 0 {
- // Reports should not be sent when the address resolves.
- resolveDAD(globalAddr, globalAddrSNMC)
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- }
- // Leave the group since we don't care about the global address's
- // solicited node multicast group membership.
- if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
- t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
- }
- if got := doneStat.Value(); got != doneCounter {
- t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
- }
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Adding a link-local address should send a report for its solicited node
- // address and globalMulticastAddr.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
- }
- if dadResolutionTime != 0 {
- reportCounter++
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
- }
- resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
- }
-
- // We expect two batches of reports to be sent (1 batch when the
- // link-local address is assigned, and another after the maximum
- // unsolicited report interval.
- for i := 0; i < 2; i++ {
- // We expect reports to be sent (one for globalMulticastAddr and another
- // for linkLocalAddrSNMC).
- reportCounter += maxReports
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
-
- addrs := map[tcpip.Address]bool{
- globalMulticastAddr: false,
- linkLocalAddrSNMC: false,
- }
- for range addrs {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
- }
-
- addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
- if seen, ok := addrs[addr]; !ok {
- t.Fatalf("got unexpected packet destined to %s", addr)
- } else if seen {
- t.Fatalf("got another packet destined to %s", addr)
- }
-
- addrs[addr] = true
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
-
- clock.Advance(ipv6.UnsolicitedReportIntervalMax)
- }
- }
-
- // Should not send any more reports.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// createAndInjectMLDPacket creates and injects an MLD packet with the
-// specified fields.
-func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, hopLimit uint8, srcAddress tcpip.Address, withRouterAlertOption bool, routerAlertValue header.IPv6RouterAlertValue) {
- var extensionHeaders header.IPv6ExtHdrSerializer
- if withRouterAlertOption {
- extensionHeaders = header.IPv6ExtHdrSerializer{
- header.IPv6SerializableHopByHopExtHdr{
- &header.IPv6RouterAlertOption{Value: routerAlertValue},
- },
- }
- }
-
- extensionHeadersLength := extensionHeaders.Length()
- payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize
- buf := buffer.NewView(header.IPv6MinimumSize + payloadLength)
-
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- HopLimit: hopLimit,
- TransportProtocol: header.ICMPv6ProtocolNumber,
- SrcAddr: srcAddress,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- ExtensionHeaders: extensionHeaders,
- })
-
- icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:])
- icmp.SetType(mldType)
- mld := header.MLD(icmp.MessageBody())
- mld.SetMaximumResponseDelay(0)
- mld.SetMulticastAddress(header.IPv6Any)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: srcAddress,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
-
- e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-func TestMLDPacketValidation(t *testing.T) {
- const nicID = 1
- linkLocalAddr2 := testutil.MustParse6("fe80::2")
-
- tests := []struct {
- name string
- messageType header.ICMPv6Type
- srcAddr tcpip.Address
- includeRouterAlertOption bool
- routerAlertValue header.IPv6RouterAlertValue
- hopLimit uint8
- expectValidMLD bool
- getMessageTypeStatValue func(tcpip.Stats) uint64
- }{
- {
- name: "valid",
- messageType: header.ICMPv6MulticastListenerQuery,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertMLD,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerQuery.Value() },
- },
- {
- name: "bad hop limit",
- messageType: header.ICMPv6MulticastListenerReport,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertMLD,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit + 1,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() },
- },
- {
- name: "src ip not link local",
- messageType: header.ICMPv6MulticastListenerReport,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertMLD,
- srcAddr: globalAddr,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() },
- },
- {
- name: "missing router alert ip option",
- messageType: header.ICMPv6MulticastListenerDone,
- includeRouterAlertOption: false,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() },
- },
- {
- name: "incorrect router alert value",
- messageType: header.ICMPv6MulticastListenerDone,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertRSVP,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- })
- e := channel.New(nicID, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- stats := s.Stats()
- // Verify that every relevant stats is zero'd before we send a packet.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 0 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got)
- }
- if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != 0 {
- t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = 0", got)
- }
- if got := stats.IP.PacketsDelivered.Value(); got != 0 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got)
- }
- createAndInjectMLDPacket(e, test.messageType, test.hopLimit, test.srcAddr, test.includeRouterAlertOption, test.routerAlertValue)
- // We always expect the packet to pass IP validation.
- if got := stats.IP.PacketsDelivered.Value(); got != 1 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got)
- }
- // Even when the MLD-specific validation checks fail, we expect the
- // corresponding MLD counter to be incremented.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 1 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got)
- }
- var expectedInvalidCount uint64
- if !test.expectValidMLD {
- expectedInvalidCount = 1
- }
- if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != expectedInvalidCount {
- t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount)
- }
- })
- }
-}
-
-func TestMLDSkipProtocol(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- group tcpip.Address
- expectReport bool
- }{
- {
- name: "Reserverd0",
- group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: false,
- },
- {
- name: "Interface Local",
- group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: false,
- },
- {
- name: "Link Local",
- group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Realm Local",
- group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Admin Local",
- group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Site Local",
- group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(6)",
- group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(7)",
- group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Organization Local",
- group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(9)",
- group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(A)",
- group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(B)",
- group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(C)",
- group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(D)",
- group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Global",
- group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "ReservedF",
- group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
- }
-
- if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil {
- t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err)
- }
- if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err)
- } else if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group)
- }
-
- if !test.expectReport {
- if p, ok := e.Read(); ok {
- t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p)
- }
-
- return
- }
-
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
deleted file mode 100644
index 52b9a200c..000000000
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ /dev/null
@@ -1,1376 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "bytes"
- "context"
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
-)
-
-// setupStackAndEndpoint creates a stack with a single NIC with a link-local
-// address llladdr and an IPv6 endpoint to a remote with link-local address
-// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
-
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
- {
- subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
-
- ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
- t.Cleanup(ep.Close)
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := llladdr.WithPrefix()
- if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- addressEP.DecRef()
- }
-
- return s, ep
-}
-
-var _ NDPDispatcher = (*testNDPDispatcher)(nil)
-
-// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
-type testNDPDispatcher struct {
- addr tcpip.Address
-}
-
-func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
-}
-
-func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
- t.addr = addr
- return true
-}
-
-func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
- t.addr = addr
-}
-
-func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
-}
-
-func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
-}
-
-func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return false
-}
-
-func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
-}
-
-func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {
-}
-
-func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
-}
-
-func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {
-}
-
-func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {
-}
-
-func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
- var ndpDisp testNDPDispatcher
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
- }
-
- ipv6EP := ep.(*endpoint)
- ipv6EP.mu.Lock()
- ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
- ipv6EP.mu.Unlock()
-
- if ndpDisp.addr != lladdr1 {
- t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
- }
-
- ndpDisp.addr = ""
- ndpEP := ep.(stack.NDPEndpoint)
- ndpEP.InvalidateDefaultRouter(lladdr1)
- if ndpDisp.addr != lladdr1 {
- t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
- }
-}
-
-type linkResolutionResult struct {
- linkAddr tcpip.LinkAddress
- ok bool
-}
-
-// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
-// valid NDP NS message with the Source Link Layer Address option results in a
-// new entry in the link address cache for the sender of the message.
-func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Valid",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
- expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
- },
- {
- name: "Too Small",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- e := channel.New(0, 1280, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.MessageBody())
- ns.SetTargetAddress(lladdr0)
- opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- neighbors, err := s.Neighbors(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
- }
-
- neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
- for _, n := range neighbors {
- if existing, ok := neighborByAddr[n.Addr]; ok {
- if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
- }
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
- }
- neighborByAddr[n.Addr] = n
- }
-
- if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 {
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
-
- if !ok {
- t.Fatalf("expected a neighbor entry for %q", lladdr1)
- }
- if neigh.LinkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr)
- }
- if neigh.State != stack.Stale {
- t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale)
- }
- } else {
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
-
- if ok {
- t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
- }
- }
- })
- }
-}
-
-func TestNeighborSolicitationResponse(t *testing.T) {
- const nicID = 1
- nicAddr := lladdr0
- remoteAddr := lladdr1
- nicAddrSNMC := header.SolicitedNodeAddr(nicAddr)
- nicLinkAddr := linkAddr0
- remoteLinkAddr0 := linkAddr1
- remoteLinkAddr1 := linkAddr2
-
- tests := []struct {
- name string
- nsOpts header.NDPOptionsSerializer
- nsSrcLinkAddr tcpip.LinkAddress
- nsSrc tcpip.Address
- nsDst tcpip.Address
- nsInvalid bool
- naDstLinkAddr tcpip.LinkAddress
- naSolicited bool
- naSrc tcpip.Address
- naDst tcpip.Address
- performsLinkResolution bool
- }{
- {
- name: "Unspecified source to solicited-node multicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddrSNMC,
- nsInvalid: false,
- naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
- naSolicited: false,
- naSrc: nicAddr,
- naDst: header.IPv6AllNodesMulticastAddress,
- },
- {
- name: "Unspecified source with source ll option to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddrSNMC,
- nsInvalid: true,
- },
- {
- name: "Unspecified source to unicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddr,
- nsInvalid: true,
- },
- {
- name: "Unspecified source with source ll option to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddr,
- nsInvalid: true,
- },
- {
- name: "Specified source with 1 source ll to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source with 1 source ll different from route to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr1,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source to multicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: true,
- },
- {
- name: "Specified source with 2 source ll to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: true,
- },
-
- {
- name: "Specified source to unicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- // Since we send a unicast solicitations to a node without an entry for
- // the remote, the node needs to perform neighbor discovery to get the
- // remote's link address to send the advertisement response.
- performsLinkResolution: true,
- },
- {
- name: "Specified source with 1 source ll to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source with 1 source ll different from route to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr1,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source with 2 source ll to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- e := channel.New(1, 1280, nicLinkAddr)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.MessageBody())
- ns.SetTargetAddress(nicAddr)
- opts := ns.Options()
- opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: test.nsSrc,
- Dst: test.nsDst,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- if test.nsInvalid {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
-
- if p, got := e.Read(); got {
- t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
- }
-
- // If we expected the NS to be invalid, we have nothing else to check.
- return
- }
-
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- if test.performsLinkResolution {
- p, got := e.ReadContext(context.Background())
- if !got {
- t.Fatal("expected an NDP NS response")
- }
-
- respNSDst := header.SolicitedNodeAddr(test.nsSrc)
- var want stack.RouteInfo
- want.NetProto = ProtocolNumber
- want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst)
- if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" {
- t.Errorf("route info mismatch (-want +got):\n%s", diff)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(nicAddr),
- checker.DstAddr(respNSDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(test.nsSrc),
- checker.NDPNSOptions([]header.NDPOption{
- header.NDPSourceLinkLayerAddressOption(nicLinkAddr),
- }),
- ))
-
- ser := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- }
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(test.nsSrc)
- na.Options().Serialize(ser)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: test.nsSrc,
- Dst: nicAddr,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
- })
- e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- p, got := e.ReadContext(context.Background())
- if !got {
- t.Fatal("expected an NDP NA response")
- }
-
- if p.Route.LocalAddress != test.naSrc {
- t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc)
- }
- if p.Route.LocalLinkAddress != nicLinkAddr {
- t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
- }
- if p.Route.RemoteAddress != test.naDst {
- t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
- }
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.naSrc),
- checker.DstAddr(test.naDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNA(
- checker.NDPNASolicitedFlag(test.naSolicited),
- checker.NDPNATargetAddress(nicAddr),
- checker.NDPNAOptions([]header.NDPOption{
- header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
- }),
- ))
- })
- }
-}
-
-// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
-// valid NDP NA message with the Target Link Layer Address option does not
-// result in a new entry in the neighbor cache for the target of the message.
-func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- isValid bool
- }{
- {
- name: "Valid",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
- isValid: true,
- },
- {
- name: "Too Small",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
- },
- {
- name: "Multiple",
- optsBuf: []byte{
- 2, 1, 2, 3, 4, 5, 6, 7,
- 2, 1, 2, 3, 4, 5, 6, 8,
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- e := channel.New(0, 1280, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.MessageBody())
- ns.SetTargetAddress(lladdr1)
- opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- neighbors, err := s.Neighbors(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
- }
-
- neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
- for _, n := range neighbors {
- if existing, ok := neighborByAddr[n.Addr]; ok {
- if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
- }
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
- }
- neighborByAddr[n.Addr] = n
- }
-
- if neigh, ok := neighborByAddr[lladdr1]; ok {
- t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
- }
-
- if test.isValid {
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
- } else {
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
- }
- })
- }
-}
-
-func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- return s, ep
- }
-
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
- var extHdrs header.IPv6ExtHdrSerializer
- if atomicFragment {
- extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
- }
- extHdrsLen := extHdrs.Length()
-
- ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen)
- header.IPv6(ip).Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + extHdrsLen),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: hopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- ExtensionHeaders: extHdrs,
- })
- vv := ip.ToVectorisedView()
- vv.AppendView(payload)
- ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- var sllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- routerOnly bool
- }{
- {
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- routerOnly: true,
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- extraData: sllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
- },
- {
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
- },
- {
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- },
- },
- }
-
- subTests := []struct {
- name string
- atomicFragment bool
- hopLimit uint8
- code header.ICMPv6Code
- valid bool
- }{
- {
- name: "Valid",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: true,
- },
- {
- name: "Fragmented",
- atomicFragment: true,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid hop limit",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit - 1,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid ICMPv6 code",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 1,
- valid: false,
- },
- }
-
- for _, typ := range types {
- for _, isRouter := range []bool{false, true} {
- name := typ.name
- if isRouter {
- name += " (Router)"
- }
-
- t.Run(name, func(t *testing.T) {
- for _, test := range subTests {
- t.Run(test.name, func(t *testing.T) {
- s, ep := setup(t)
-
- if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- routerOnly := stats.RouterOnlyPacketsDroppedByHost
- typStat := typ.statCounter(stats)
-
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp[:typ.size],
- Src: lladdr0,
- Dst: lladdr1,
- PayloadCsum: header.Checksum(typ.extraData /* initial */, 0),
- PayloadLen: len(typ.extraData),
- }))
-
- // Rx count of the NDP message should initially be 0.
- if got := typStat.Value(); got != 0 {
- t.Errorf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
-
- // RouterOnlyPacketsReceivedByHost count should initially be 0.
- if got := routerOnly.Value(); got != 0 {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
-
- // Rx count of the NDP packet should have increased.
- if got := typStat.Value(); got != 1 {
- t.Errorf("got %s = %d, want = 1", typ.name, got)
- }
-
- want := uint64(0)
- if !test.valid {
- // Invalid count should have increased.
- want = 1
- }
- if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
- }
-
- want = 0
- if test.valid && !isRouter && typ.routerOnly {
- // RouterOnlyPacketsReceivedByHost count should have increased.
- want = 1
- }
- if got := routerOnly.Value(); got != want {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
- }
-
- })
- }
- })
- }
- }
-}
-
-// TestNeighborAdvertisementValidation tests that the NIC validates received
-// Neighbor Advertisements.
-//
-// In particular, if the IP Destination Address is a multicast address, and the
-// Solicited flag is not zero, the Neighbor Advertisement is invalid and should
-// be discarded.
-func TestNeighborAdvertisementValidation(t *testing.T) {
- tests := []struct {
- name string
- ipDstAddr tcpip.Address
- solicitedFlag bool
- valid bool
- }{
- {
- name: "Multicast IP destination address with Solicited flag set",
- ipDstAddr: header.IPv6AllNodesMulticastAddress,
- solicitedFlag: true,
- valid: false,
- },
- {
- name: "Multicast IP destination address with Solicited flag unset",
- ipDstAddr: header.IPv6AllNodesMulticastAddress,
- solicitedFlag: false,
- valid: true,
- },
- {
- name: "Unicast IP destination address with Solicited flag set",
- ipDstAddr: lladdr0,
- solicitedFlag: true,
- valid: true,
- },
- {
- name: "Unicast IP destination address with Solicited flag unset",
- ipDstAddr: lladdr0,
- solicitedFlag: false,
- valid: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- e := channel.New(0, header.IPv6MinimumMTU, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetTargetAddress(lladdr1)
- na.SetSolicitedFlag(test.solicitedFlag)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: lladdr1,
- Dst: test.ipDstAddr,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: test.ipDstAddr,
- })
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- rxNA := stats.NeighborAdvert
-
- if got := rxNA.Value(); got != 0 {
- t.Fatalf("got rxNA = %d, want = 0", got)
- }
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- if got := rxNA.Value(); got != 1 {
- t.Fatalf("got rxNA = %d, want = 1", got)
- }
- var wantInvalid uint64 = 1
- if test.valid {
- wantInvalid = 0
- }
- if got := invalid.Value(); got != wantInvalid {
- t.Fatalf("got invalid = %d, want = %d", got, wantInvalid)
- }
- // As per RFC 4861 section 7.2.5:
- // When a valid Neighbor Advertisement is received ...
- // If no entry exists, the advertisement SHOULD be silently discarded.
- // There is no need to create an entry if none exists, since the
- // recipient has apparently not initiated any communication with the
- // target.
- if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
- } else if len(neighbors) != 0 {
- t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
- }
- })
- }
-}
-
-// TestRouterAdvertValidation tests that when the NIC is configured to handle
-// NDP Router Advertisement packets, it validates the Router Advertisement
-// properly before handling them.
-func TestRouterAdvertValidation(t *testing.T) {
- tests := []struct {
- name string
- src tcpip.Address
- hopLimit uint8
- code header.ICMPv6Code
- ndpPayload []byte
- expectedSuccess bool
- }{
- {
- "OK",
- lladdr0,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- true,
- },
- {
- "NonLinkLocalSourceAddr",
- addr1,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "HopLimitNot255",
- lladdr0,
- 254,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "NonZeroCode",
- lladdr0,
- 255,
- 1,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "NDPPayloadTooSmall",
- lladdr0,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0,
- },
- false,
- },
- {
- "OKWithOptions",
- lladdr0,
- 255,
- 0,
- []byte{
- // RA payload
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
-
- // Option #1 (TargetLinkLayerAddress)
- 2, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #2 (unrecognized)
- 255, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #3 (PrefixInformation)
- 3, 4, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- },
- true,
- },
- {
- "OptionWithZeroLength",
- lladdr0,
- 255,
- 0,
- []byte{
- // RA payload
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
-
- // Option #1 (TargetLinkLayerAddress)
- // Invalid as it has 0 length.
- 2, 0, 0, 0, 0, 0, 0, 0,
-
- // Option #2 (unrecognized)
- 255, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #3 (PrefixInformation)
- 3, 4, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- },
- false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.MessageBody(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: test.src,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
-
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
-
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- }
- })
- }
-}
-
-// TestCheckDuplicateAddress checks that calls to CheckDuplicateAddress and DAD
-// performed when adding new addresses do not interfere with each other.
-func TestCheckDuplicateAddress(t *testing.T) {
- const nicID = 1
-
- clock := faketime.NewManualClock()
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
-
- nonces := [...][]byte{
- {1, 2, 3, 4, 5, 6},
- {7, 8, 9, 10, 11, 12},
- }
-
- var secureRNGBytes []byte
- for _, n := range nonces {
- secureRNGBytes = append(secureRNGBytes, n...)
- }
- var secureRNG bytes.Reader
- secureRNG.Reset(secureRNGBytes[:])
- s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
- Clock: clock,
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
- DADConfigs: dadConfigs,
- })},
- })
- // This test is expected to send at max 2 DAD messages. We allow an extra
- // packet to be stored to catch unexpected packets.
- e := channel.New(3, header.IPv6MinimumMTU, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- dadPacketsSent := 0
- snmc := header.SolicitedNodeAddr(lladdr0)
- remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
- checkDADMsg := func() {
- p, ok := e.ReadContext(context.Background())
- if !ok {
- t.Fatalf("expected %d-th DAD message", dadPacketsSent)
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber)
- }
-
- if p.Route.RemoteLinkAddress != remoteLinkAddr {
- t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
- ))
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
- checkDADMsg()
-
- // Start DAD for the address we just added.
- //
- // Even though the stack will perform DAD before the added address transitions
- // from tentative to assigned, this DAD request should be independent of that.
- ch := make(chan stack.DADResult, 3)
- dadRequestsMade := 1
- dadPacketsSent++
- if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err)
- } else if res != stack.DADStarting {
- t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting)
- }
- checkDADMsg()
-
- // Remove the address and make sure our DAD request was not stopped.
- if err := s.RemoveAddress(nicID, lladdr0); err != nil {
- t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err)
- }
- // Should not restart DAD since we already requested DAD above - the handler
- // should be called when the original request compeletes so we should not send
- // an extra DAD message here.
- dadRequestsMade++
- if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err)
- } else if res != stack.DADAlreadyRunning {
- t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning)
- }
-
- // Wait for DAD to complete.
- clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
- for i := 0; i < dadRequestsMade; i++ {
- if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" {
- t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
- }
- }
- // Should have no more results.
- select {
- case r := <-ch:
- t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
-
- // Should have no more packets.
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
-}
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
deleted file mode 100644
index 1b96b1fb8..000000000
--- a/pkg/tcpip/network/multicast_group_test.go
+++ /dev/null
@@ -1,1278 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "fmt"
- "strings"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
-
- defaultIPv4PrefixLength = 24
-
- igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
- igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
- igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
- igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
- mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
- mldReport = uint8(header.ICMPv6MulticastListenerReport)
- mldDone = uint8(header.ICMPv6MulticastListenerDone)
-
- maxUnsolicitedReports = 2
-)
-
-var (
- stackIPv4Addr = testutil.MustParse4("10.0.0.1")
- linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1")
- linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2")
-
- ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3")
- ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4")
- ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5")
- ipv6MulticastAddr1 = testutil.MustParse6("ff02::3")
- ipv6MulticastAddr2 = testutil.MustParse6("ff02::4")
- ipv6MulticastAddr3 = testutil.MustParse6("ff02::5")
-)
-
-var (
- // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
- // NIC will wait before sending an unsolicited report after joining a
- // multicast group, in deciseconds.
- unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
- const decisecond = time.Second / 10
- if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
- panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
- }
- return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
- }()
-
- ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1)
-)
-
-// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
-// sent to the provided address with the passed fields set.
-func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
- t.Helper()
-
- payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
- checker.IPv6WithExtHdr(t, payload,
- checker.IPv6ExtHdr(
- checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
- ),
- checker.SrcAddr(linkLocalIPv6Addr1),
- checker.DstAddr(remoteAddress),
- // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
- checker.TTL(1),
- checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
- checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
- checker.MLDMulticastAddress(groupAddress),
- ),
- )
-}
-
-// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
-// sent to the provided address with the passed fields set.
-func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
- t.Helper()
-
- payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
- checker.IPv4(t, payload,
- checker.SrcAddr(stackIPv4Addr),
- checker.DstAddr(remoteAddress),
- // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
- checker.TTL(1),
- checker.IPv4RouterAlert(),
- checker.IGMP(
- checker.IGMPType(header.IGMPType(igmpType)),
- checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
- checker.IGMPGroupAddress(groupAddress),
- ),
- )
-}
-
-func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
- t.Helper()
-
- e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
- s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
- return e, s, clock
-}
-
-func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
- t.Helper()
-
- igmpEnabled := v4 && mgpEnabled
- mldEnabled := !v4 && mgpEnabled
-
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocolWithOptions(ipv4.Options{
- IGMP: ipv4.IGMPOptions{
- Enabled: igmpEnabled,
- },
- }),
- ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: mldEnabled,
- },
- }),
- },
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- addr := tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
- }
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
- }
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
- }
-
- return s, clock
-}
-
-// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
-// when it is created with an IPv6 address.
-//
-// To not interfere with tests, checkInitialIPv6Groups will leave the added
-// address's solicited node multicast group so that the tests can all assume
-// the NIC has not joined any IPv6 groups.
-func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
- t.Helper()
-
- stats := s.Stats().ICMP.V6.PacketsSent
-
- reportCounter++
- if got := stats.MulticastListenerReport.Value(); got != reportCounter {
- t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
- }
-
- // Leave the group to not affect the tests. This is fine since we are not
- // testing DAD or the solicited node address specifically.
- if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
- t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
- }
- leaveCounter++
- if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
- t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
-
- return reportCounter, leaveCounter
-}
-
-// createAndInjectIGMPPacket creates and injects an IGMP packet with the
-// specified fields.
-func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
- options := header.IPv4OptionsSerializer{
- &header.IPv4SerializableRouterAlertOption{},
- }
- buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize)
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: header.IGMPTTL,
- Protocol: uint8(header.IGMPProtocolNumber),
- SrcAddr: remoteIPv4Addr,
- DstAddr: header.IPv4AllSystems,
- Options: options,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- igmp := header.IGMP(ip.Payload())
- igmp.SetType(header.IGMPType(igmpType))
- igmp.SetMaxRespTime(maxRespTime)
- igmp.SetGroupAddress(groupAddress)
- igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
-
- e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-// createAndInjectMLDPacket creates and injects an MLD packet with the
-// specified fields.
-func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
- extensionHeaders := header.IPv6ExtHdrSerializer{
- header.IPv6SerializableHopByHopExtHdr{
- &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
- },
- }
-
- extensionHeadersLength := extensionHeaders.Length()
- payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize
- buf := buffer.NewView(header.IPv6MinimumSize + payloadLength)
-
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- HopLimit: header.MLDHopLimit,
- TransportProtocol: header.ICMPv6ProtocolNumber,
- SrcAddr: linkLocalIPv6Addr2,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- ExtensionHeaders: extensionHeaders,
- })
-
- icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:])
- icmp.SetType(header.ICMPv6Type(mldType))
- mld := header.MLD(icmp.MessageBody())
- mld.SetMaximumResponseDelay(uint16(maxRespDelay))
- mld.SetMulticastAddress(groupAddress)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: linkLocalIPv6Addr2,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
-
- e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-// TestMGPDisabled tests that the multicast group protocol is not enabled by
-// default.
-func TestMGPDisabled(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
- rxQuery func(*channel.Endpoint)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- rxQuery: func(e *channel.Endpoint) {
- createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- rxQuery: func(e *channel.Endpoint) {
- createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
-
- // This NIC may join multicast groups when it is enabled but since MGP is
- // disabled, no reports should be sent.
- sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
- }
-
- // Test joining a specific group explicitly and verify that no reports are
- // sent.
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
- }
-
- // Inject a general query message. This should only trigger a report to be
- // sent if the MGP was enabled.
- test.rxQuery(e)
- if got := test.receivedQueryStat(s).Value(); got != 1 {
- t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
- })
- }
-}
-
-func TestMGPReceiveCounters(t *testing.T) {
- tests := []struct {
- name string
- headerType uint8
- maxRespTime byte
- groupAddress tcpip.Address
- statCounter func(*stack.Stack) *tcpip.StatCounter
- rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address)
- }{
- {
- name: "IGMP Membership Query",
- headerType: igmpMembershipQuery,
- maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec,
- groupAddress: header.IPv4Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "IGMPv1 Membership Report",
- headerType: igmpv1MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.V1MembershipReport
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "IGMPv2 Membership Report",
- headerType: igmpv2MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.V2MembershipReport
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "IGMP Leave Group",
- headerType: igmpLeaveGroup,
- maxRespTime: 0,
- groupAddress: header.IPv4AllRoutersGroup,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.LeaveGroup
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "MLD Query",
- headerType: mldQuery,
- maxRespTime: 0,
- groupAddress: header.IPv6Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- rxMGPkt: createAndInjectMLDPacket,
- },
- {
- name: "MLD Report",
- headerType: mldReport,
- maxRespTime: 0,
- groupAddress: header.IPv6Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
- },
- rxMGPkt: createAndInjectMLDPacket,
- },
- {
- name: "MLD Done",
- headerType: mldDone,
- maxRespTime: 0,
- groupAddress: header.IPv6Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
- },
- rxMGPkt: createAndInjectMLDPacket,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
-
- test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
- if got := test.statCounter(s).Value(); got != 1 {
- t.Fatalf("got %s received = %d, want = 1", test.name, got)
- }
- })
- }
-}
-
-// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
-// stack schedules and sends correct Membership Reports.
-func TestMGPJoinGroup(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- maxUnsolicitedResponseDelay time.Duration
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo)
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
- }
-
- // Test joining a specific address explicitly and verify a Report is sent
- // immediately.
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- reportCounter++
- sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Verify the second report is sent by the maximum unsolicited response
- // interval.
- p, ok := e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
- }
- clock.Advance(test.maxUnsolicitedResponseDelay)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// TestMGPLeaveGroup tests that when leaving a previously joined multicast
-// group the stack sends a leave/done message.
-func TestMGPLeaveGroup(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo)
- validateLeave func(*testing.T, channel.PacketInfo)
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.LeaveGroup
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- var leaveCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
- }
-
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- reportCounter++
- if got := test.sentReportStat(s).Value(); got != reportCounter {
- t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leaving the group should trigger an leave/done message to be sent.
- if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
- }
- leaveCounter++
- if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a leave message to be sent")
- } else {
- test.validateLeave(t, p)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// TestMGPQueryMessages tests that a report is sent in response to query
-// messages.
-func TestMGPQueryMessages(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- maxUnsolicitedResponseDelay time.Duration
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
- rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
- validateReport func(*testing.T, channel.PacketInfo)
- maxRespTimeToDuration func(uint8) time.Duration
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
- createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- maxRespTimeToDuration: header.DecisecondToDuration,
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
- createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- maxRespTimeToDuration: func(d uint8) time.Duration {
- return time.Duration(d) * time.Millisecond
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- subTests := []struct {
- name string
- multicastAddr tcpip.Address
- expectReport bool
- }{
- {
- name: "Unspecified",
- multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
- expectReport: true,
- },
- {
- name: "Specified",
- multicastAddr: test.multicastAddr,
- expectReport: true,
- },
- {
- name: "Specified other address",
- multicastAddr: func() tcpip.Address {
- addrBytes := []byte(test.multicastAddr)
- addrBytes[len(addrBytes)-1]++
- return tcpip.Address(addrBytes)
- }(),
- expectReport: false,
- },
- }
-
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
- }
-
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- sentReportStat := test.sentReportStat(s)
- for i := 0; i < maxUnsolicitedReports; i++ {
- sentReportStat := test.sentReportStat(s)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatalf("expected %d-th report message to be sent", i)
- } else {
- test.validateReport(t, p)
- }
- clock.Advance(test.maxUnsolicitedResponseDelay)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Should not send any more packets until a query.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
-
- // Receive a query message which should trigger a report to be sent at
- // some time before the maximum response time if the report is
- // targeted at the host.
- const maxRespTime = 100
- test.rxQuery(e, maxRespTime, subTest.multicastAddr)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p.Pkt)
- }
-
- if subTest.expectReport {
- clock.Advance(test.maxRespTimeToDuration(maxRespTime))
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
- })
- }
-}
-
-// TestMGPQueryMessages tests that no further reports or leave/done messages
-// are sent after receiving a report.
-func TestMGPReportMessages(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- rxReport func(*channel.Endpoint)
- validateReport func(*testing.T, channel.PacketInfo)
- maxRespTimeToDuration func(uint8) time.Duration
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.LeaveGroup
- },
- rxReport: func(e *channel.Endpoint) {
- createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- maxRespTimeToDuration: header.DecisecondToDuration,
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- },
- rxReport: func(e *channel.Endpoint) {
- createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- maxRespTimeToDuration: func(d uint8) time.Duration {
- return time.Duration(d) * time.Millisecond
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- var leaveCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
- }
-
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- sentReportStat := test.sentReportStat(s)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Receiving a report for a group we joined should cancel any further
- // reports.
- test.rxReport(e)
- clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); ok {
- t.Errorf("sent unexpected packet = %#v", p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leaving a group after getting a report should not send a leave/done
- // message.
- if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
- }
- clock.Advance(time.Hour)
- if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-func TestMGPWithNICLifecycle(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddrs []tcpip.Address
- finalMulticastAddr tcpip.Address
- maxUnsolicitedResponseDelay time.Duration
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
- validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
- getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
- finalMulticastAddr: ipv4MulticastAddr3,
- maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.LeaveGroup
- },
- validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
- },
- getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
- t.Helper()
-
- ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
- if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
- t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
- }
- addr := header.IGMP(ipv4.Payload()).GroupAddress()
- s, ok := seen[addr]
- if !ok {
- t.Fatalf("unexpectedly got a packet for group %s", addr)
- }
- if s {
- t.Fatalf("already saw packet for group %s", addr)
- }
- seen[addr] = true
- return addr
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
- finalMulticastAddr: ipv6MulticastAddr3,
- maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- },
- validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateMLDPacket(t, p, addr, mldReport, 0, addr)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr)
- },
- getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
- t.Helper()
-
- ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
-
- ipv6HeaderIter := header.MakeIPv6PayloadIterator(
- header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
- buffer.View(ipv6.Payload()).ToVectorisedView(),
- )
-
- var transport header.IPv6RawPayloadHeader
- for {
- h, done, err := ipv6HeaderIter.Next()
- if err != nil {
- t.Fatalf("ipv6HeaderIter.Next(): %s", err)
- }
- if done {
- t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
- }
- if t, ok := h.(header.IPv6RawPayloadHeader); ok {
- transport = t
- break
- }
- }
-
- if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
- t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
- }
- icmpv6 := header.ICMPv6(transport.Buf.ToView())
- if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
- t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
- }
- addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
- s, ok := seen[addr]
- if !ok {
- t.Fatalf("unexpectedly got a packet for group %s", addr)
- }
- if s {
- t.Fatalf("already saw packet for group %s", addr)
- }
- seen[addr] = true
- return addr
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- var leaveCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
- }
-
- sentReportStat := test.sentReportStat(s)
- for _, a := range test.multicastAddrs {
- if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
- }
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatalf("expected a report message to be sent for %s", a)
- } else {
- test.validateReport(t, p, a)
- }
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leave messages should be sent for the joined groups when the NIC is
- // disabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("DisableNIC(%d): %s", nicID, err)
- }
- sentLeaveStat := test.sentLeaveStat(s)
- leaveCounter += uint64(len(test.multicastAddrs))
- if got := sentLeaveStat.Value(); got != leaveCounter {
- t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
- }
- {
- seen := make(map[tcpip.Address]bool)
- for _, a := range test.multicastAddrs {
- seen[a] = false
- }
-
- for i := range test.multicastAddrs {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected (%d-th) leave message to be sent", i)
- }
-
- test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
- }
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Reports should be sent for the joined groups when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("EnableNIC(%d): %s", nicID, err)
- }
- reportCounter += uint64(len(test.multicastAddrs))
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- {
- seen := make(map[tcpip.Address]bool)
- for _, a := range test.multicastAddrs {
- seen[a] = false
- }
-
- for i := range test.multicastAddrs {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected (%d-th) report message to be sent", i)
- }
-
- test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
- }
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Joining/leaving a group while disabled should not send any messages.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("DisableNIC(%d): %s", nicID, err)
- }
- leaveCounter += uint64(len(test.multicastAddrs))
- if got := sentLeaveStat.Value(); got != leaveCounter {
- t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
- }
- for i := range test.multicastAddrs {
- if _, ok := e.Read(); !ok {
- t.Fatalf("expected (%d-th) leave message to be sent", i)
- }
- }
- for _, a := range test.multicastAddrs {
- if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
- t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
- }
- if got := sentLeaveStat.Value(); got != leaveCounter {
- t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
- }
- if p, ok := e.Read(); ok {
- t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
- }
- }
- if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
- }
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); ok {
- t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
- }
-
- // A report should only be sent for the group we last joined after
- // enabling the NIC since the original groups were all left.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("EnableNIC(%d): %s", nicID, err)
- }
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p, test.finalMulticastAddr)
- }
-
- clock.Advance(test.maxUnsolicitedResponseDelay)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p, test.finalMulticastAddr)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
-// performed on loopback interfaces since they have no neighbours.
-func TestMGPDisabledOnLoopback(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
-
- sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
-
- // Test joining a specific group explicitly and verify that no reports are
- // sent.
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- })
- }
-}